From 8b765718966383f94f2169af40d922fb291a5a39 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 16 May 2024 15:36:25 -0700 Subject: [PATCH] Fix extensions (#1126) * fix extensions * title * enable circle * fix nanobind tag * fix bug in doc * try to fix config * typo --- .circleci/config.yml | 14 ++++---------- docs/src/dev/extensions.rst | 16 ++++++++-------- examples/extensions/README.md | 8 +++++++- examples/extensions/axpby/axpby.cpp | 10 +++++----- .../extensions/mlx_sample_extensions/__init__.py | 2 +- examples/extensions/requirements.txt | 2 +- examples/extensions/test.py | 10 ++++++++++ 7 files changed, 36 insertions(+), 26 deletions(-) create mode 100644 examples/extensions/test.py diff --git a/.circleci/config.yml b/.circleci/config.yml index d40b10dc2..e9eebff69 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -49,11 +49,6 @@ jobs: name: Run Python tests command: | python3 -m unittest discover python/tests -v - # TODO: Reenable when extension api becomes stable - # - run: - # name: Build example extension - # command: | - # cd examples/extensions && python3 -m pip install . - run: name: Build CPP only command: | @@ -101,11 +96,10 @@ jobs: source env/bin/activate LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu - # TODO: Reenable when extension api becomes stable - # - run: - # name: Build example extension - # command: | - # cd examples/extensions && python3.11 -m pip install . + - run: + name: Build example extension + command: | + cd examples/extensions && python3.8 -m pip install . 
- store_test_results: path: test-results - run: diff --git a/docs/src/dev/extensions.rst b/docs/src/dev/extensions.rst index acf41a773..9a2be90cd 100644 --- a/docs/src/dev/extensions.rst +++ b/docs/src/dev/extensions.rst @@ -1,5 +1,5 @@ -Developer Documentation -======================= +Custom Extensions in MLX +======================== You can extend MLX with custom operations on the CPU or GPU. This guide explains how to do that with a simple example. @@ -494,7 +494,7 @@ below. auto kernel = d.get_kernel(kname.str(), "mlx_ext"); // Prepare to encode kernel - auto compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder->setComputePipelineState(kernel); // Kernel parameters are registered with buffer indices corresponding to @@ -503,11 +503,11 @@ below. size_t nelem = out.size(); // Encode input arrays to kernel - set_array_buffer(compute_encoder, x, 0); - set_array_buffer(compute_encoder, y, 1); + compute_encoder.set_input_array(x, 0); + compute_encoder.set_input_array(y, 1); // Encode output arrays to kernel - set_array_buffer(compute_encoder, out, 2); + compute_encoder.set_output_array(out, 2); // Encode alpha and beta compute_encoder->setBytes(&alpha_, sizeof(float), 3); @@ -531,7 +531,7 @@ below. // Launch the grid with the given number of threads divided among // the given threadgroups - compute_encoder->dispatchThreads(grid_dims, group_dims); + compute_encoder.dispatchThreads(grid_dims, group_dims); } We can now call the :meth:`axpby` operation on both the CPU and the GPU! 
@@ -825,7 +825,7 @@ Let's look at a simple script and its results: print(f"c shape: {c.shape}") print(f"c dtype: {c.dtype}") - print(f"c correctness: {mx.all(c == 6.0).item()}") + print(f"c correct: {mx.all(c == 6.0).item()}") Output: diff --git a/examples/extensions/README.md b/examples/extensions/README.md index 17582bc0f..1cb113459 100644 --- a/examples/extensions/README.md +++ b/examples/extensions/README.md @@ -1,5 +1,5 @@ -## Build the extensions +## Build ``` pip install -e . @@ -16,3 +16,9 @@ And then run: ``` python setup.py build_ext -j8 --inplace ``` + +## Test + +``` +python test.py +``` diff --git a/examples/extensions/axpby/axpby.cpp b/examples/extensions/axpby/axpby.cpp index bfd308e7c..57c7d7900 100644 --- a/examples/extensions/axpby/axpby.cpp +++ b/examples/extensions/axpby/axpby.cpp @@ -257,7 +257,7 @@ void Axpby::eval_gpu( auto kernel = d.get_kernel(kname.str(), "mlx_ext"); // Prepare to encode kernel - auto compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder->setComputePipelineState(kernel); // Kernel parameters are registered with buffer indices corresponding to @@ -266,11 +266,11 @@ void Axpby::eval_gpu( size_t nelem = out.size(); // Encode input arrays to kernel - set_array_buffer(compute_encoder, x, 0); - set_array_buffer(compute_encoder, y, 1); + compute_encoder.set_input_array(x, 0); + compute_encoder.set_input_array(y, 1); // Encode output arrays to kernel - set_array_buffer(compute_encoder, out, 2); + compute_encoder.set_output_array(out, 2); // Encode alpha and beta compute_encoder->setBytes(&alpha_, sizeof(float), 3); @@ -296,7 +296,7 @@ void Axpby::eval_gpu( // Launch the grid with the given number of threads divided among // the given threadgroups - compute_encoder->dispatchThreads(grid_dims, group_dims); + compute_encoder.dispatchThreads(grid_dims, group_dims); } #else // Metal is not available diff --git a/examples/extensions/mlx_sample_extensions/__init__.py 
b/examples/extensions/mlx_sample_extensions/__init__.py index 4f3acce9e..b56a54fd2 100644 --- a/examples/extensions/mlx_sample_extensions/__init__.py +++ b/examples/extensions/mlx_sample_extensions/__init__.py @@ -2,4 +2,4 @@ import mlx.core as mx -from .mlx_sample_extensions import * +from ._ext import axpby diff --git a/examples/extensions/requirements.txt b/examples/extensions/requirements.txt index 01a7d3864..cecbc3338 100644 --- a/examples/extensions/requirements.txt +++ b/examples/extensions/requirements.txt @@ -1,4 +1,4 @@ setuptools>=42 cmake>=3.24 mlx>=0.9.0 -nanobind@git+https://github.com/wjakob/nanobind.git#egg=4148debcf91f5ccab0c3b8d67b5c3cabd61f407f +nanobind@git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4 diff --git a/examples/extensions/test.py b/examples/extensions/test.py new file mode 100644 index 000000000..f00a72e85 --- /dev/null +++ b/examples/extensions/test.py @@ -0,0 +1,10 @@ +import mlx.core as mx +from mlx_sample_extensions import axpby + +a = mx.ones((3, 4)) +b = mx.ones((3, 4)) +c = axpby(a, b, 4.0, 2.0, stream=mx.cpu) + +print(f"c shape: {c.shape}") +print(f"c dtype: {c.dtype}") +print(f"c correct: {mx.all(c == 6.0).item()}")