MLX in C++ example (#1736)

* MLX in C++ example

* nits

* fix docs
This commit is contained in:
Awni Hannun 2025-01-02 19:09:04 -08:00 committed by GitHub
parent 8544b42007
commit c9d30aa6ac
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 242 additions and 31 deletions

121
docs/src/dev/mlx_in_cpp.rst Normal file
View File

@ -0,0 +1,121 @@
.. _mlx_in_cpp:
Using MLX in C++
================
You can use MLX in a C++ project with CMake.
.. note::
This guide is based one the following `example using MLX in C++
<https://github.com/ml-explore/mlx/tree/main/examples/cmake_project>`_
First install MLX:
.. code-block:: bash
pip install -U mlx
You can also install the MLX Python package from source or just the C++
library. For more information see the :ref:`documentation on installing MLX
<build_and_install>`.
Next make an example program in ``example.cpp``:
.. code-block:: C++
#include <iostream>
#include "mlx/mlx.h"
namespace mx = mlx::core;
int main() {
auto x = mx::array({1, 2, 3});
auto y = mx::array({1, 2, 3});
std::cout << x + y << std::endl;
return 0;
}
The next step is to setup a CMake file in ``CMakeLists.txt``:
.. code-block:: cmake
cmake_minimum_required(VERSION 3.27)
project(example LANGUAGES CXX)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
Depending on how you installed MLX, you may need to tell CMake where to
find it.
If you installed MLX with Python, then add the following to the CMake file:
.. code-block:: cmake
find_package(
Python 3.9
COMPONENTS Interpreter Development.Module
REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE MLX_ROOT)
If you installed the MLX C++ package to a system path, then CMake should be
able to find it. If you installed it to a non-standard location or CMake can't
find MLX then set ``MLX_ROOT`` to the location where MLX is installed:
.. code-block:: cmake
set(MLX_ROOT "/path/to/mlx/")
Next, instruct CMake to find MLX:
.. code-block:: cmake
find_package(MLX CONFIG REQUIRED)
Finally, add the ``example.cpp`` program as an executable and link MLX.
.. code-block:: cmake
add_executable(example example.cpp)
target_link_libraries(example PRIVATE mlx)
You can build the example with:
.. code-block:: bash
cmake -B build -DCMAKE_BUILD_TYPE=Release
cmake --build build
And run it with:
.. code-block:: bash
./build/example
Note ``find_package(MLX CONFIG REQUIRED)`` sets the following variables:
.. list-table:: Package Variables
:widths: 20 20
:header-rows: 1
* - Variable
- Description
* - MLX_FOUND
- ``True`` if MLX is found
* - MLX_INCLUDE_DIRS
- Include directory
* - MLX_LIBRARIES
- Libraries to link against
* - MLX_CXX_FLAGS
- Additional compiler flags
* - MLX_BUILD_ACCELERATE
- ``True`` if MLX was built with Accelerate
* - MLX_BUILD_METAL
- ``True`` if MLX was built with Metal

View File

@ -87,3 +87,4 @@ are the CPU and GPU.
dev/extensions dev/extensions
dev/metal_debugger dev/metal_debugger
dev/custom_metal_kernels dev/custom_metal_kernels
dev/mlx_in_cpp

View File

@ -1,3 +1,5 @@
.. _build_and_install:
Build and Install Build and Install
================= =================

View File

@ -89,6 +89,7 @@ Operations
isneginf isneginf
isposinf isposinf
issubdtype issubdtype
kron
left_shift left_shift
less less
less_equal less_equal

View File

@ -0,0 +1,22 @@
cmake_minimum_required(VERSION 3.27)
project(example LANGUAGES CXX)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
# Comment the following two commands only the MLX C++ library is installed and
# set(MLX_ROOT "/path/to/mlx") directly if needed.
find_package(
Python 3.9
COMPONENTS Interpreter Development.Module
REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE MLX_ROOT)
find_package(MLX CONFIG REQUIRED)
add_executable(example example.cpp)
target_link_libraries(example PRIVATE mlx)

View File

@ -0,0 +1,26 @@
## Build and Run
Install MLX with Python:
```bash
pip install mlx>=0.22
```
Build the C++ example:
```bash
cmake -B build -DCMAKE_BUILD_TYPE=Release
cmake --build build
```
Run the C++ example:
```
./build/example
```
which should output:
```
array([2, 4, 6], dtype=int32)
```

View File

@ -0,0 +1,14 @@
// Copyright © 2024 Apple Inc.
#include <iostream>
#include "mlx/mlx.h"
namespace mx = mlx::core;
int main() {
auto x = mx::array({1, 2, 3});
auto y = mx::array({1, 2, 3});
std::cout << x + y << std::endl;
return 0;
}

View File

@ -2,20 +2,15 @@ cmake_minimum_required(VERSION 3.27)
project(import_mlx LANGUAGES CXX) project(import_mlx LANGUAGES CXX)
# ----------------------------- Setup -----------------------------
set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
# ----------------------------- Dependencies -----------------------------
find_package( find_package(
Python 3.9 Python 3.9
COMPONENTS Interpreter Development.Module COMPONENTS Interpreter Development.Module
REQUIRED) REQUIRED)
execute_process( execute_process(
COMMAND "${Python_EXECUTABLE}" -m pip show mlx COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
COMMAND grep location
COMMAND awk "{print $4 \"/mlx\"}"
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE MLX_ROOT) OUTPUT_VARIABLE MLX_ROOT)
find_package(MLX CONFIG REQUIRED) find_package(MLX CONFIG REQUIRED)

View File

@ -3,18 +3,18 @@
#include <mlx/mlx.h> #include <mlx/mlx.h>
#include <iostream> #include <iostream>
using namespace mlx::core; namespace mx = mlx::core;
int main() { int main() {
int batch_size = 8; int batch_size = 8;
int input_dim = 32; int input_dim = 32;
// Make the input // Make the input
random::seed(42); mx::random::seed(42);
auto example_x = random::uniform({batch_size, input_dim}); auto example_x = mx::random::uniform({batch_size, input_dim});
// Import the function // Import the function
auto forward = import_function("eval_mlp.mlxfn"); auto forward = mx::import_function("eval_mlp.mlxfn");
// Call the imported function // Call the imported function
auto out = forward({example_x})[0]; auto out = forward({example_x})[0];

View File

@ -3,22 +3,22 @@
#include <mlx/mlx.h> #include <mlx/mlx.h>
#include <iostream> #include <iostream>
using namespace mlx::core; namespace mx = mlx::core;
int main() { int main() {
int batch_size = 8; int batch_size = 8;
int input_dim = 32; int input_dim = 32;
int output_dim = 10; int output_dim = 10;
auto state = import_function("init_mlp.mlxfn")({}); auto state = mx::import_function("init_mlp.mlxfn")({});
// Make the input // Make the input
random::seed(42); mx::random::seed(42);
auto example_X = random::normal({batch_size, input_dim}); auto example_X = mx::random::normal({batch_size, input_dim});
auto example_y = random::randint(0, output_dim, {batch_size}); auto example_y = mx::random::randint(0, output_dim, {batch_size});
// Import the function // Import the function
auto step = import_function("train_mlp.mlxfn"); auto step = mx::import_function("train_mlp.mlxfn");
// Call the imported function // Call the imported function
for (int it = 0; it < 100; ++it) { for (int it = 0; it < 100; ++it) {

View File

@ -914,7 +914,7 @@ inline array gather(
return gather(a, {indices}, std::vector<int>{axis}, slice_sizes, s); return gather(a, {indices}, std::vector<int>{axis}, slice_sizes, s);
} }
/** Returns Kronecker Producct given two input arrays. */ /** Compute the Kronecker product of two arrays. */
array kron(const array& a, const array& b, StreamOrDevice s = {}); array kron(const array& a, const array& b, StreamOrDevice s = {});
/** Take array slices at the given indices of the specified axis. */ /** Take array slices at the given indices of the specified axis. */

27
python/mlx/__main__.py Normal file
View File

@ -0,0 +1,27 @@
import argparse
def main() -> None:
from mlx.core import __version__
parser = argparse.ArgumentParser()
parser.add_argument(
"--version",
action="version",
version=__version__,
help="Print the version number.",
)
parser.add_argument(
"--cmake-dir",
action="store_true",
help="Print the path to the MLX CMake module directory.",
)
args = parser.parse_args()
if args.cmake_dir:
from pathlib import Path
print(Path(__file__).parent)
if __name__ == "__main__":
main()

View File

@ -1468,24 +1468,26 @@ void init_ops(nb::module_& m) {
nb::sig( nb::sig(
"def kron(a: array, b: array, *, stream: Union[None, Stream, Device] = None) -> array"), "def kron(a: array, b: array, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc( R"pbdoc(
Compute the Kronecker product of two arrays `a` and `b`. Compute the Kronecker product of two arrays ``a`` and ``b``.
Args: Args:
a (array): The first input array a (array): The first input array.
b (array): The second input array b (array): The second input array.
stream (Union[None, Stream, Device], optional): Optional stream or device for execution. stream (Union[None, Stream, Device], optional): Optional stream or
Default is `None`. device for execution. Default: ``None``.
Returns: Returns:
array: The Kronecker product of `a` and `b`. array: The Kronecker product of ``a`` and ``b``.
Examples: Examples:
>>> import mlx >>> a = mx.array([[1, 2], [3, 4]])
>>> a = mlx.array([[1, 2], [3, 4]]) >>> b = mx.array([[0, 5], [6, 7]])
>>> b = mlx.array([[0, 5], [6, 7]]) >>> result = mx.kron(a, b)
>>> result = mlx.kron(a, b)
>>> print(result) >>> print(result)
[[ 0 5 0 10] array([[0, 5, 0, 10],
[ 6 7 12 14] [6, 7, 12, 14],
[ 0 15 0 20] [0, 15, 0, 20],
[18 21 24 28]] [18, 21, 24, 28]], dtype=int32)
)pbdoc"); )pbdoc");
m.def( m.def(
"take", "take",