mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Fix extensions (#1126)
* fix extensions * title * enable circle * fix nanobind tag * fix bug in doc * try to fix config * typo
This commit is contained in:
parent
e78a6518fa
commit
8b76571896
@ -49,11 +49,6 @@ jobs:
|
|||||||
name: Run Python tests
|
name: Run Python tests
|
||||||
command: |
|
command: |
|
||||||
python3 -m unittest discover python/tests -v
|
python3 -m unittest discover python/tests -v
|
||||||
# TODO: Reenable when extension api becomes stable
|
|
||||||
# - run:
|
|
||||||
# name: Build example extension
|
|
||||||
# command: |
|
|
||||||
# cd examples/extensions && python3 -m pip install .
|
|
||||||
- run:
|
- run:
|
||||||
name: Build CPP only
|
name: Build CPP only
|
||||||
command: |
|
command: |
|
||||||
@ -101,11 +96,10 @@ jobs:
|
|||||||
source env/bin/activate
|
source env/bin/activate
|
||||||
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
|
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
|
||||||
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
|
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
|
||||||
# TODO: Reenable when extension api becomes stable
|
- run:
|
||||||
# - run:
|
name: Build example extension
|
||||||
# name: Build example extension
|
command: |
|
||||||
# command: |
|
cd examples/extensions && python3.8 -m pip install .
|
||||||
# cd examples/extensions && python3.11 -m pip install .
|
|
||||||
- store_test_results:
|
- store_test_results:
|
||||||
path: test-results
|
path: test-results
|
||||||
- run:
|
- run:
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
Developer Documentation
|
Custom Extensions in MLX
|
||||||
=======================
|
========================
|
||||||
|
|
||||||
You can extend MLX with custom operations on the CPU or GPU. This guide
|
You can extend MLX with custom operations on the CPU or GPU. This guide
|
||||||
explains how to do that with a simple example.
|
explains how to do that with a simple example.
|
||||||
@ -494,7 +494,7 @@ below.
|
|||||||
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
|
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
|
||||||
|
|
||||||
// Prepare to encode kernel
|
// Prepare to encode kernel
|
||||||
auto compute_encoder = d.get_command_encoder(s.index);
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
compute_encoder->setComputePipelineState(kernel);
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
|
||||||
// Kernel parameters are registered with buffer indices corresponding to
|
// Kernel parameters are registered with buffer indices corresponding to
|
||||||
@ -503,11 +503,11 @@ below.
|
|||||||
size_t nelem = out.size();
|
size_t nelem = out.size();
|
||||||
|
|
||||||
// Encode input arrays to kernel
|
// Encode input arrays to kernel
|
||||||
set_array_buffer(compute_encoder, x, 0);
|
compute_encoder.set_input_array(x, 0);
|
||||||
set_array_buffer(compute_encoder, y, 1);
|
compute_encoder.set_input_array(y, 1);
|
||||||
|
|
||||||
// Encode output arrays to kernel
|
// Encode output arrays to kernel
|
||||||
set_array_buffer(compute_encoder, out, 2);
|
compute_encoder.set_output_array(out, 2);
|
||||||
|
|
||||||
// Encode alpha and beta
|
// Encode alpha and beta
|
||||||
compute_encoder->setBytes(&alpha_, sizeof(float), 3);
|
compute_encoder->setBytes(&alpha_, sizeof(float), 3);
|
||||||
@ -531,7 +531,7 @@ below.
|
|||||||
|
|
||||||
// Launch the grid with the given number of threads divided among
|
// Launch the grid with the given number of threads divided among
|
||||||
// the given threadgroups
|
// the given threadgroups
|
||||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||||
}
|
}
|
||||||
|
|
||||||
We can now call the :meth:`axpby` operation on both the CPU and the GPU!
|
We can now call the :meth:`axpby` operation on both the CPU and the GPU!
|
||||||
@ -825,7 +825,7 @@ Let's look at a simple script and its results:
|
|||||||
|
|
||||||
print(f"c shape: {c.shape}")
|
print(f"c shape: {c.shape}")
|
||||||
print(f"c dtype: {c.dtype}")
|
print(f"c dtype: {c.dtype}")
|
||||||
print(f"c correctness: {mx.all(c == 6.0).item()}")
|
print(f"c correct: {mx.all(c == 6.0).item()}")
|
||||||
|
|
||||||
Output:
|
Output:
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
|
|
||||||
## Build the extensions
|
## Build
|
||||||
|
|
||||||
```
|
```
|
||||||
pip install -e .
|
pip install -e .
|
||||||
@ -16,3 +16,9 @@ And then run:
|
|||||||
```
|
```
|
||||||
python setup.py build_ext -j8 --inplace
|
python setup.py build_ext -j8 --inplace
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Test
|
||||||
|
|
||||||
|
```
|
||||||
|
python test.py
|
||||||
|
```
|
||||||
|
@ -257,7 +257,7 @@ void Axpby::eval_gpu(
|
|||||||
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
|
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
|
||||||
|
|
||||||
// Prepare to encode kernel
|
// Prepare to encode kernel
|
||||||
auto compute_encoder = d.get_command_encoder(s.index);
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
compute_encoder->setComputePipelineState(kernel);
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
|
||||||
// Kernel parameters are registered with buffer indices corresponding to
|
// Kernel parameters are registered with buffer indices corresponding to
|
||||||
@ -266,11 +266,11 @@ void Axpby::eval_gpu(
|
|||||||
size_t nelem = out.size();
|
size_t nelem = out.size();
|
||||||
|
|
||||||
// Encode input arrays to kernel
|
// Encode input arrays to kernel
|
||||||
set_array_buffer(compute_encoder, x, 0);
|
compute_encoder.set_input_array(x, 0);
|
||||||
set_array_buffer(compute_encoder, y, 1);
|
compute_encoder.set_input_array(y, 1);
|
||||||
|
|
||||||
// Encode output arrays to kernel
|
// Encode output arrays to kernel
|
||||||
set_array_buffer(compute_encoder, out, 2);
|
compute_encoder.set_output_array(out, 2);
|
||||||
|
|
||||||
// Encode alpha and beta
|
// Encode alpha and beta
|
||||||
compute_encoder->setBytes(&alpha_, sizeof(float), 3);
|
compute_encoder->setBytes(&alpha_, sizeof(float), 3);
|
||||||
@ -296,7 +296,7 @@ void Axpby::eval_gpu(
|
|||||||
|
|
||||||
// Launch the grid with the given number of threads divided among
|
// Launch the grid with the given number of threads divided among
|
||||||
// the given threadgroups
|
// the given threadgroups
|
||||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // Metal is not available
|
#else // Metal is not available
|
||||||
|
@ -2,4 +2,4 @@
|
|||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
|
||||||
from .mlx_sample_extensions import *
|
from ._ext import axpby
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
setuptools>=42
|
setuptools>=42
|
||||||
cmake>=3.24
|
cmake>=3.24
|
||||||
mlx>=0.9.0
|
mlx>=0.9.0
|
||||||
nanobind@git+https://github.com/wjakob/nanobind.git#egg=4148debcf91f5ccab0c3b8d67b5c3cabd61f407f
|
nanobind@git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
|
||||||
|
10
examples/extensions/test.py
Normal file
10
examples/extensions/test.py
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
import mlx.core as mx
|
||||||
|
from mlx_sample_extensions import axpby
|
||||||
|
|
||||||
|
a = mx.ones((3, 4))
|
||||||
|
b = mx.ones((3, 4))
|
||||||
|
c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
|
||||||
|
|
||||||
|
print(f"c shape: {c.shape}")
|
||||||
|
print(f"c dtype: {c.dtype}")
|
||||||
|
print(f"c correct: {mx.all(c == 6.0).item()}")
|
Loading…
Reference in New Issue
Block a user