MLX in C++ example (#1736)

* MLX in C++ example

* nits

* fix docs
This commit is contained in:
Awni Hannun 2025-01-02 19:09:04 -08:00 committed by GitHub
parent 8544b42007
commit c9d30aa6ac
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 242 additions and 31 deletions

121
docs/src/dev/mlx_in_cpp.rst Normal file
View File

@ -0,0 +1,121 @@
.. _mlx_in_cpp:
Using MLX in C++
================
You can use MLX in a C++ project with CMake.
.. note::
This guide is based one the following `example using MLX in C++
<https://github.com/ml-explore/mlx/tree/main/examples/cmake_project>`_
First install MLX:
.. code-block:: bash
pip install -U mlx
You can also install the MLX Python package from source or just the C++
library. For more information see the :ref:`documentation on installing MLX
<build_and_install>`.
Next make an example program in ``example.cpp``:
.. code-block:: C++
#include <iostream>
#include "mlx/mlx.h"
namespace mx = mlx::core;
int main() {
auto x = mx::array({1, 2, 3});
auto y = mx::array({1, 2, 3});
std::cout << x + y << std::endl;
return 0;
}
The next step is to setup a CMake file in ``CMakeLists.txt``:
.. code-block:: cmake
cmake_minimum_required(VERSION 3.27)
project(example LANGUAGES CXX)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
Depending on how you installed MLX, you may need to tell CMake where to
find it.
If you installed MLX with Python, then add the following to the CMake file:
.. code-block:: cmake
find_package(
Python 3.9
COMPONENTS Interpreter Development.Module
REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE MLX_ROOT)
If you installed the MLX C++ package to a system path, then CMake should be
able to find it. If you installed it to a non-standard location or CMake can't
find MLX then set ``MLX_ROOT`` to the location where MLX is installed:
.. code-block:: cmake
set(MLX_ROOT "/path/to/mlx/")
Next, instruct CMake to find MLX:
.. code-block:: cmake
find_package(MLX CONFIG REQUIRED)
Finally, add the ``example.cpp`` program as an executable and link MLX.
.. code-block:: cmake
add_executable(example example.cpp)
target_link_libraries(example PRIVATE mlx)
You can build the example with:
.. code-block:: bash
cmake -B build -DCMAKE_BUILD_TYPE=Release
cmake --build build
And run it with:
.. code-block:: bash
./build/example
Note ``find_package(MLX CONFIG REQUIRED)`` sets the following variables:
.. list-table:: Package Variables
:widths: 20 20
:header-rows: 1
* - Variable
- Description
* - MLX_FOUND
- ``True`` if MLX is found
* - MLX_INCLUDE_DIRS
- Include directory
* - MLX_LIBRARIES
- Libraries to link against
* - MLX_CXX_FLAGS
- Additional compiler flags
* - MLX_BUILD_ACCELERATE
- ``True`` if MLX was built with Accelerate
* - MLX_BUILD_METAL
- ``True`` if MLX was built with Metal

View File

@ -87,3 +87,4 @@ are the CPU and GPU.
dev/extensions
dev/metal_debugger
dev/custom_metal_kernels
dev/mlx_in_cpp

View File

@ -1,3 +1,5 @@
.. _build_and_install:
Build and Install
=================

View File

@ -89,6 +89,7 @@ Operations
isneginf
isposinf
issubdtype
kron
left_shift
less
less_equal

View File

@ -0,0 +1,22 @@
cmake_minimum_required(VERSION 3.27)
project(example LANGUAGES CXX)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
# Comment the following two commands only the MLX C++ library is installed and
# set(MLX_ROOT "/path/to/mlx") directly if needed.
find_package(
Python 3.9
COMPONENTS Interpreter Development.Module
REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE MLX_ROOT)
find_package(MLX CONFIG REQUIRED)
add_executable(example example.cpp)
target_link_libraries(example PRIVATE mlx)

View File

@ -0,0 +1,26 @@
## Build and Run
Install MLX with Python:
```bash
pip install mlx>=0.22
```
Build the C++ example:
```bash
cmake -B build -DCMAKE_BUILD_TYPE=Release
cmake --build build
```
Run the C++ example:
```
./build/example
```
which should output:
```
array([2, 4, 6], dtype=int32)
```

View File

@ -0,0 +1,14 @@
// Copyright © 2024 Apple Inc.
#include <iostream>
#include "mlx/mlx.h"
namespace mx = mlx::core;
int main() {
auto x = mx::array({1, 2, 3});
auto y = mx::array({1, 2, 3});
std::cout << x + y << std::endl;
return 0;
}

View File

@ -2,20 +2,15 @@ cmake_minimum_required(VERSION 3.27)
project(import_mlx LANGUAGES CXX)
# ----------------------------- Setup -----------------------------
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
# ----------------------------- Dependencies -----------------------------
find_package(
Python 3.9
COMPONENTS Interpreter Development.Module
REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m pip show mlx
COMMAND grep location
COMMAND awk "{print $4 \"/mlx\"}"
COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE MLX_ROOT)
find_package(MLX CONFIG REQUIRED)

View File

@ -3,18 +3,18 @@
#include <mlx/mlx.h>
#include <iostream>
using namespace mlx::core;
namespace mx = mlx::core;
int main() {
int batch_size = 8;
int input_dim = 32;
// Make the input
random::seed(42);
auto example_x = random::uniform({batch_size, input_dim});
mx::random::seed(42);
auto example_x = mx::random::uniform({batch_size, input_dim});
// Import the function
auto forward = import_function("eval_mlp.mlxfn");
auto forward = mx::import_function("eval_mlp.mlxfn");
// Call the imported function
auto out = forward({example_x})[0];

View File

@ -3,22 +3,22 @@
#include <mlx/mlx.h>
#include <iostream>
using namespace mlx::core;
namespace mx = mlx::core;
int main() {
int batch_size = 8;
int input_dim = 32;
int output_dim = 10;
auto state = import_function("init_mlp.mlxfn")({});
auto state = mx::import_function("init_mlp.mlxfn")({});
// Make the input
random::seed(42);
auto example_X = random::normal({batch_size, input_dim});
auto example_y = random::randint(0, output_dim, {batch_size});
mx::random::seed(42);
auto example_X = mx::random::normal({batch_size, input_dim});
auto example_y = mx::random::randint(0, output_dim, {batch_size});
// Import the function
auto step = import_function("train_mlp.mlxfn");
auto step = mx::import_function("train_mlp.mlxfn");
// Call the imported function
for (int it = 0; it < 100; ++it) {

View File

@ -914,7 +914,7 @@ inline array gather(
return gather(a, {indices}, std::vector<int>{axis}, slice_sizes, s);
}
/** Returns Kronecker Producct given two input arrays. */
/** Compute the Kronecker product of two arrays. */
array kron(const array& a, const array& b, StreamOrDevice s = {});
/** Take array slices at the given indices of the specified axis. */

27
python/mlx/__main__.py Normal file
View File

@ -0,0 +1,27 @@
import argparse
def main() -> None:
from mlx.core import __version__
parser = argparse.ArgumentParser()
parser.add_argument(
"--version",
action="version",
version=__version__,
help="Print the version number.",
)
parser.add_argument(
"--cmake-dir",
action="store_true",
help="Print the path to the MLX CMake module directory.",
)
args = parser.parse_args()
if args.cmake_dir:
from pathlib import Path
print(Path(__file__).parent)
if __name__ == "__main__":
main()

View File

@ -1468,24 +1468,26 @@ void init_ops(nb::module_& m) {
nb::sig(
"def kron(a: array, b: array, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Compute the Kronecker product of two arrays `a` and `b`.
Compute the Kronecker product of two arrays ``a`` and ``b``.
Args:
a (array): The first input array
b (array): The second input array
stream (Union[None, Stream, Device], optional): Optional stream or device for execution.
Default is `None`.
a (array): The first input array.
b (array): The second input array.
stream (Union[None, Stream, Device], optional): Optional stream or
device for execution. Default: ``None``.
Returns:
array: The Kronecker product of `a` and `b`.
array: The Kronecker product of ``a`` and ``b``.
Examples:
>>> import mlx
>>> a = mlx.array([[1, 2], [3, 4]])
>>> b = mlx.array([[0, 5], [6, 7]])
>>> result = mlx.kron(a, b)
>>> a = mx.array([[1, 2], [3, 4]])
>>> b = mx.array([[0, 5], [6, 7]])
>>> result = mx.kron(a, b)
>>> print(result)
[[ 0 5 0 10]
[ 6 7 12 14]
[ 0 15 0 20]
[18 21 24 28]]
array([[0, 5, 0, 10],
[6, 7, 12, 14],
[0, 15, 0, 20],
[18, 21, 24, 28]], dtype=int32)
)pbdoc");
m.def(
"take",