Fix extensions (#1126)

* fix extensions

* title

* enable circle

* fix nanobind tag

* fix bug in doc

* try to fix config

* typo
Awni Hannun 2024-05-16 15:36:25 -07:00 committed by GitHub
parent e78a6518fa
commit 8b76571896
7 changed files with 36 additions and 26 deletions

View File

@@ -49,11 +49,6 @@ jobs:
           name: Run Python tests
           command: |
             python3 -m unittest discover python/tests -v
-      # TODO: Reenable when extension api becomes stable
-      # - run:
-      #     name: Build example extension
-      #     command: |
-      #       cd examples/extensions && python3 -m pip install .
       - run:
           name: Build CPP only
           command: |
@@ -101,11 +96,10 @@ jobs:
             source env/bin/activate
             LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
             LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
-      # TODO: Reenable when extension api becomes stable
-      # - run:
-      #     name: Build example extension
-      #     command: |
-      #       cd examples/extensions && python3.11 -m pip install .
+      - run:
+          name: Build example extension
+          command: |
+            cd examples/extensions && python3.8 -m pip install .
       - store_test_results:
           path: test-results
       - run:

View File

@@ -1,5 +1,5 @@
-Developer Documentation
-=======================
+Custom Extensions in MLX
+========================

 You can extend MLX with custom operations on the CPU or GPU. This guide
 explains how to do that with a simple example.
@@ -494,7 +494,7 @@ below.
   auto kernel = d.get_kernel(kname.str(), "mlx_ext");

   // Prepare to encode kernel
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder->setComputePipelineState(kernel);

   // Kernel parameters are registered with buffer indices corresponding to
@@ -503,11 +503,11 @@ below.
   size_t nelem = out.size();

   // Encode input arrays to kernel
-  set_array_buffer(compute_encoder, x, 0);
-  set_array_buffer(compute_encoder, y, 1);
+  compute_encoder.set_input_array(x, 0);
+  compute_encoder.set_input_array(y, 1);

   // Encode output arrays to kernel
-  set_array_buffer(compute_encoder, out, 2);
+  compute_encoder.set_output_array(out, 2);

   // Encode alpha and beta
   compute_encoder->setBytes(&alpha_, sizeof(float), 3);
@@ -531,7 +531,7 @@ below.
   // Launch the grid with the given number of threads divided among
   // the given threadgroups
-  compute_encoder->dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatchThreads(grid_dims, group_dims);
 }

 We can now call the :meth:`axpby` operation on both the CPU and the GPU!
@@ -825,7 +825,7 @@ Let's look at a simple script and its results:
    print(f"c shape: {c.shape}")
    print(f"c dtype: {c.dtype}")
-   print(f"c correctness: {mx.all(c == 6.0).item()}")
+   print(f"c correct: {mx.all(c == 6.0).item()}")

 Output:
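
For context, the doc's claim that `axpby` now runs on both devices boils down to the minimal sketch below. It assumes the example package is importable as `mlx_sample_extensions` (the same assumption the new `test.py` in this commit makes) and that Metal is available for the GPU stream.

```python
import mlx.core as mx
from mlx_sample_extensions import axpby

a = mx.ones((3, 4))
b = mx.ones((3, 4))

# alpha * x + beta * y with alpha = 4, beta = 2 -> every entry is 6
c_cpu = axpby(a, b, 4.0, 2.0, stream=mx.cpu)  # CPU implementation
c_gpu = axpby(a, b, 4.0, 2.0, stream=mx.gpu)  # Metal kernel, i.e. the updated encoder path

print(mx.all(c_cpu == 6.0).item(), mx.all(c_gpu == 6.0).item())
```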

View File

@@ -1,5 +1,5 @@
-## Build the extensions
+## Build

 ```
 pip install -e .
@@ -16,3 +16,9 @@ And then run:
 ```
 python setup.py build_ext -j8 --inplace
 ```
+
+## Test
+
+```
+python test.py
+```

View File

@@ -257,7 +257,7 @@ void Axpby::eval_gpu(
   auto kernel = d.get_kernel(kname.str(), "mlx_ext");

   // Prepare to encode kernel
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder->setComputePipelineState(kernel);

   // Kernel parameters are registered with buffer indices corresponding to
@@ -266,11 +266,11 @@ void Axpby::eval_gpu(
   size_t nelem = out.size();

   // Encode input arrays to kernel
-  set_array_buffer(compute_encoder, x, 0);
-  set_array_buffer(compute_encoder, y, 1);
+  compute_encoder.set_input_array(x, 0);
+  compute_encoder.set_input_array(y, 1);

   // Encode output arrays to kernel
-  set_array_buffer(compute_encoder, out, 2);
+  compute_encoder.set_output_array(out, 2);

   // Encode alpha and beta
   compute_encoder->setBytes(&alpha_, sizeof(float), 3);
@@ -296,7 +296,7 @@ void Axpby::eval_gpu(
   // Launch the grid with the given number of threads divided among
   // the given threadgroups
-  compute_encoder->dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatchThreads(grid_dims, group_dims);
 }

 #else // Metal is not available

View File

@@ -2,4 +2,4 @@
 import mlx.core as mx

-from .mlx_sample_extensions import *
+from ._ext import axpby
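
The effect of this change is that the compiled nanobind module is built as `_ext` inside the package and only `axpby` is re-exported, so downstream code imports from the package name rather than from the binary module. A minimal sketch of the consumer side, assuming the package is installed (e.g. with `pip install -e .`):

```python
# The package __init__ re-exports axpby from the compiled _ext module,
# so user code never imports _ext directly.
import mlx.core as mx
from mlx_sample_extensions import axpby

c = axpby(mx.ones((2,)), mx.ones((2,)), 1.0, 1.0)  # each entry is 1*1 + 1*1 = 2
print(c)
```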

View File

@@ -1,4 +1,4 @@
 setuptools>=42
 cmake>=3.24
 mlx>=0.9.0
-nanobind@git+https://github.com/wjakob/nanobind.git#egg=4148debcf91f5ccab0c3b8d67b5c3cabd61f407f
+nanobind@git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4

View File

@@ -0,0 +1,10 @@
+import mlx.core as mx
+from mlx_sample_extensions import axpby
+
+a = mx.ones((3, 4))
+b = mx.ones((3, 4))
+c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
+
+print(f"c shape: {c.shape}")
+print(f"c dtype: {c.dtype}")
+print(f"c correct: {mx.all(c == 6.0).item()}")