mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
parent
8544b42007
commit
c9d30aa6ac
121
docs/src/dev/mlx_in_cpp.rst
Normal file
121
docs/src/dev/mlx_in_cpp.rst
Normal file
@ -0,0 +1,121 @@
|
||||
.. _mlx_in_cpp:
|
||||
|
||||
Using MLX in C++
|
||||
================
|
||||
|
||||
You can use MLX in a C++ project with CMake.
|
||||
|
||||
.. note::
|
||||
|
||||
This guide is based one the following `example using MLX in C++
|
||||
<https://github.com/ml-explore/mlx/tree/main/examples/cmake_project>`_
|
||||
|
||||
First install MLX:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -U mlx
|
||||
|
||||
You can also install the MLX Python package from source or just the C++
|
||||
library. For more information see the :ref:`documentation on installing MLX
|
||||
<build_and_install>`.
|
||||
|
||||
Next make an example program in ``example.cpp``:
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
|
||||
int main() {
|
||||
auto x = mx::array({1, 2, 3});
|
||||
auto y = mx::array({1, 2, 3});
|
||||
std::cout << x + y << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
The next step is to setup a CMake file in ``CMakeLists.txt``:
|
||||
|
||||
.. code-block:: cmake
|
||||
|
||||
cmake_minimum_required(VERSION 3.27)
|
||||
|
||||
project(example LANGUAGES CXX)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
|
||||
|
||||
Depending on how you installed MLX, you may need to tell CMake where to
|
||||
find it.
|
||||
|
||||
If you installed MLX with Python, then add the following to the CMake file:
|
||||
|
||||
.. code-block:: cmake
|
||||
|
||||
find_package(
|
||||
Python 3.9
|
||||
COMPONENTS Interpreter Development.Module
|
||||
REQUIRED)
|
||||
execute_process(
|
||||
COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
OUTPUT_VARIABLE MLX_ROOT)
|
||||
|
||||
If you installed the MLX C++ package to a system path, then CMake should be
|
||||
able to find it. If you installed it to a non-standard location or CMake can't
|
||||
find MLX then set ``MLX_ROOT`` to the location where MLX is installed:
|
||||
|
||||
.. code-block:: cmake
|
||||
|
||||
set(MLX_ROOT "/path/to/mlx/")
|
||||
|
||||
Next, instruct CMake to find MLX:
|
||||
|
||||
.. code-block:: cmake
|
||||
|
||||
find_package(MLX CONFIG REQUIRED)
|
||||
|
||||
Finally, add the ``example.cpp`` program as an executable and link MLX.
|
||||
|
||||
.. code-block:: cmake
|
||||
|
||||
add_executable(example example.cpp)
|
||||
target_link_libraries(example PRIVATE mlx)
|
||||
|
||||
You can build the example with:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
cmake -B build -DCMAKE_BUILD_TYPE=Release
|
||||
cmake --build build
|
||||
|
||||
And run it with:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
./build/example
|
||||
|
||||
Note ``find_package(MLX CONFIG REQUIRED)`` sets the following variables:
|
||||
|
||||
.. list-table:: Package Variables
|
||||
:widths: 20 20
|
||||
:header-rows: 1
|
||||
|
||||
* - Variable
|
||||
- Description
|
||||
* - MLX_FOUND
|
||||
- ``True`` if MLX is found
|
||||
* - MLX_INCLUDE_DIRS
|
||||
- Include directory
|
||||
* - MLX_LIBRARIES
|
||||
- Libraries to link against
|
||||
* - MLX_CXX_FLAGS
|
||||
- Additional compiler flags
|
||||
* - MLX_BUILD_ACCELERATE
|
||||
- ``True`` if MLX was built with Accelerate
|
||||
* - MLX_BUILD_METAL
|
||||
- ``True`` if MLX was built with Metal
|
@ -87,3 +87,4 @@ are the CPU and GPU.
|
||||
dev/extensions
|
||||
dev/metal_debugger
|
||||
dev/custom_metal_kernels
|
||||
dev/mlx_in_cpp
|
||||
|
@ -1,3 +1,5 @@
|
||||
.. _build_and_install:
|
||||
|
||||
Build and Install
|
||||
=================
|
||||
|
||||
|
@ -89,6 +89,7 @@ Operations
|
||||
isneginf
|
||||
isposinf
|
||||
issubdtype
|
||||
kron
|
||||
left_shift
|
||||
less
|
||||
less_equal
|
||||
|
22
examples/cmake_project/CMakeLists.txt
Normal file
22
examples/cmake_project/CMakeLists.txt
Normal file
@ -0,0 +1,22 @@
|
||||
cmake_minimum_required(VERSION 3.27)
|
||||
|
||||
project(example LANGUAGES CXX)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
|
||||
# Comment the following two commands only the MLX C++ library is installed and
|
||||
# set(MLX_ROOT "/path/to/mlx") directly if needed.
|
||||
find_package(
|
||||
Python 3.9
|
||||
COMPONENTS Interpreter Development.Module
|
||||
REQUIRED)
|
||||
execute_process(
|
||||
COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
OUTPUT_VARIABLE MLX_ROOT)
|
||||
|
||||
find_package(MLX CONFIG REQUIRED)
|
||||
|
||||
add_executable(example example.cpp)
|
||||
target_link_libraries(example PRIVATE mlx)
|
26
examples/cmake_project/README.md
Normal file
26
examples/cmake_project/README.md
Normal file
@ -0,0 +1,26 @@
|
||||
## Build and Run
|
||||
|
||||
Install MLX with Python:
|
||||
|
||||
```bash
|
||||
pip install mlx>=0.22
|
||||
```
|
||||
|
||||
Build the C++ example:
|
||||
|
||||
```bash
|
||||
cmake -B build -DCMAKE_BUILD_TYPE=Release
|
||||
cmake --build build
|
||||
```
|
||||
|
||||
Run the C++ example:
|
||||
|
||||
```
|
||||
./build/example
|
||||
```
|
||||
|
||||
which should output:
|
||||
|
||||
```
|
||||
array([2, 4, 6], dtype=int32)
|
||||
```
|
14
examples/cmake_project/example.cpp
Normal file
14
examples/cmake_project/example.cpp
Normal file
@ -0,0 +1,14 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
|
||||
int main() {
|
||||
auto x = mx::array({1, 2, 3});
|
||||
auto y = mx::array({1, 2, 3});
|
||||
std::cout << x + y << std::endl;
|
||||
return 0;
|
||||
}
|
@ -2,20 +2,15 @@ cmake_minimum_required(VERSION 3.27)
|
||||
|
||||
project(import_mlx LANGUAGES CXX)
|
||||
|
||||
# ----------------------------- Setup -----------------------------
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
||||
|
||||
# ----------------------------- Dependencies -----------------------------
|
||||
find_package(
|
||||
Python 3.9
|
||||
COMPONENTS Interpreter Development.Module
|
||||
REQUIRED)
|
||||
execute_process(
|
||||
COMMAND "${Python_EXECUTABLE}" -m pip show mlx
|
||||
COMMAND grep location
|
||||
COMMAND awk "{print $4 \"/mlx\"}"
|
||||
COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
OUTPUT_VARIABLE MLX_ROOT)
|
||||
find_package(MLX CONFIG REQUIRED)
|
||||
|
@ -3,18 +3,18 @@
|
||||
#include <mlx/mlx.h>
|
||||
#include <iostream>
|
||||
|
||||
using namespace mlx::core;
|
||||
namespace mx = mlx::core;
|
||||
|
||||
int main() {
|
||||
int batch_size = 8;
|
||||
int input_dim = 32;
|
||||
|
||||
// Make the input
|
||||
random::seed(42);
|
||||
auto example_x = random::uniform({batch_size, input_dim});
|
||||
mx::random::seed(42);
|
||||
auto example_x = mx::random::uniform({batch_size, input_dim});
|
||||
|
||||
// Import the function
|
||||
auto forward = import_function("eval_mlp.mlxfn");
|
||||
auto forward = mx::import_function("eval_mlp.mlxfn");
|
||||
|
||||
// Call the imported function
|
||||
auto out = forward({example_x})[0];
|
||||
|
@ -3,22 +3,22 @@
|
||||
#include <mlx/mlx.h>
|
||||
#include <iostream>
|
||||
|
||||
using namespace mlx::core;
|
||||
namespace mx = mlx::core;
|
||||
|
||||
int main() {
|
||||
int batch_size = 8;
|
||||
int input_dim = 32;
|
||||
int output_dim = 10;
|
||||
|
||||
auto state = import_function("init_mlp.mlxfn")({});
|
||||
auto state = mx::import_function("init_mlp.mlxfn")({});
|
||||
|
||||
// Make the input
|
||||
random::seed(42);
|
||||
auto example_X = random::normal({batch_size, input_dim});
|
||||
auto example_y = random::randint(0, output_dim, {batch_size});
|
||||
mx::random::seed(42);
|
||||
auto example_X = mx::random::normal({batch_size, input_dim});
|
||||
auto example_y = mx::random::randint(0, output_dim, {batch_size});
|
||||
|
||||
// Import the function
|
||||
auto step = import_function("train_mlp.mlxfn");
|
||||
auto step = mx::import_function("train_mlp.mlxfn");
|
||||
|
||||
// Call the imported function
|
||||
for (int it = 0; it < 100; ++it) {
|
||||
|
@ -914,7 +914,7 @@ inline array gather(
|
||||
return gather(a, {indices}, std::vector<int>{axis}, slice_sizes, s);
|
||||
}
|
||||
|
||||
/** Returns Kronecker Producct given two input arrays. */
|
||||
/** Compute the Kronecker product of two arrays. */
|
||||
array kron(const array& a, const array& b, StreamOrDevice s = {});
|
||||
|
||||
/** Take array slices at the given indices of the specified axis. */
|
||||
|
27
python/mlx/__main__.py
Normal file
27
python/mlx/__main__.py
Normal file
@ -0,0 +1,27 @@
|
||||
import argparse
|
||||
|
||||
|
||||
def main() -> None:
|
||||
from mlx.core import __version__
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--version",
|
||||
action="version",
|
||||
version=__version__,
|
||||
help="Print the version number.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cmake-dir",
|
||||
action="store_true",
|
||||
help="Print the path to the MLX CMake module directory.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
if args.cmake_dir:
|
||||
from pathlib import Path
|
||||
|
||||
print(Path(__file__).parent)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1468,24 +1468,26 @@ void init_ops(nb::module_& m) {
|
||||
nb::sig(
|
||||
"def kron(a: array, b: array, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Compute the Kronecker product of two arrays `a` and `b`.
|
||||
Compute the Kronecker product of two arrays ``a`` and ``b``.
|
||||
|
||||
Args:
|
||||
a (array): The first input array
|
||||
b (array): The second input array
|
||||
stream (Union[None, Stream, Device], optional): Optional stream or device for execution.
|
||||
Default is `None`.
|
||||
a (array): The first input array.
|
||||
b (array): The second input array.
|
||||
stream (Union[None, Stream, Device], optional): Optional stream or
|
||||
device for execution. Default: ``None``.
|
||||
|
||||
Returns:
|
||||
array: The Kronecker product of `a` and `b`.
|
||||
array: The Kronecker product of ``a`` and ``b``.
|
||||
|
||||
Examples:
|
||||
>>> import mlx
|
||||
>>> a = mlx.array([[1, 2], [3, 4]])
|
||||
>>> b = mlx.array([[0, 5], [6, 7]])
|
||||
>>> result = mlx.kron(a, b)
|
||||
>>> a = mx.array([[1, 2], [3, 4]])
|
||||
>>> b = mx.array([[0, 5], [6, 7]])
|
||||
>>> result = mx.kron(a, b)
|
||||
>>> print(result)
|
||||
[[ 0 5 0 10]
|
||||
[ 6 7 12 14]
|
||||
[ 0 15 0 20]
|
||||
[18 21 24 28]]
|
||||
array([[0, 5, 0, 10],
|
||||
[6, 7, 12, 14],
|
||||
[0, 15, 0, 20],
|
||||
[18, 21, 24, 28]], dtype=int32)
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"take",
|
||||
|
Loading…
Reference in New Issue
Block a user