Extensions (#962)

* start to fix extensions

* mostly fixed extensions

* fix extension build

* couple more nits
Awni Hannun 2024-04-09 08:50:36 -07:00 committed by GitHub
parent 42afe27e12
commit b63ef10a7f
9 changed files with 292 additions and 309 deletions


Developer Documentation
=======================
You can extend MLX with custom operations on the CPU or GPU. This guide
explains how to do that with a simple example.

Introducing the Example
-----------------------

Let's say you would like an operation that takes in two arrays, ``x`` and
``y``, scales them both by coefficients ``alpha`` and ``beta`` respectively,
and then adds them together to get the result ``z = alpha * x + beta * y``.
You can do that in MLX directly:

.. code-block:: python

    import mlx.core as mx

    def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
        return alpha * x + beta * y

This function performs that operation while leaving the implementation and
function transformations to MLX.

However, you may need to customize the underlying implementation, perhaps to
make it faster or for custom differentiation. In this tutorial we will go
through adding custom extensions. It will cover:

* The structure of the MLX library.
* Implementing a CPU operation that redirects to Accelerate_ when appropriate.
* Implementing a GPU operation using Metal.
* Adding the ``vjp`` and ``jvp`` function transformations.
* Building a custom extension and binding it to Python.

Operations and Primitives
-------------------------

Operations in MLX build the computation graph. Primitives provide the rules
for evaluating and transforming the graph. Let's start by discussing
operations in more detail.

Operations
^^^^^^^^^^

Operations are the front-end functions that operate on arrays. They are defined
in the C++ API (:ref:`cpp_ops`), and the Python API (:ref:`ops`) binds them.

We would like an operation, :meth:`axpby`, that takes in two arrays ``x`` and
``y``, and two scalars, ``alpha`` and ``beta``. This is how to define it in
C++:

.. code-block:: C++

    array axpby(
        const array& x, // Input array x
        const array& y, // Input array y
        const float alpha, // Scaling factor for x
        const float beta, // Scaling factor for y
        StreamOrDevice s = {} // Stream on which to schedule the operation
    );

The simplest way to implement this operation is in terms of existing
operations:

.. code-block:: C++

    array axpby(
        const array& x, // Input array x
        const array& y, // Input array y
        const float alpha, // Scaling factor for x
        const float beta, // Scaling factor for y
        StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
    ) {
      // Scale x and y on the provided stream
      auto ax = multiply(array(alpha), x, s);
      auto by = multiply(array(beta), y, s);

      // Add and return
      return add(ax, by, s);
    }

The operations themselves do not contain the implementations that act on the
data, nor do they contain the rules of transformations. Rather, they are an
easy-to-use interface built on top of :class:`Primitive` building blocks.

Primitives
^^^^^^^^^^

A :class:`Primitive` is part of the computation graph of an :class:`array`. It
defines how to create output arrays given input arrays. Further, a
:class:`Primitive` has methods to run on the CPU or GPU and for function
transformations such as ``vjp`` and ``jvp``. Let's go back to our example to
be more concrete:

.. code-block:: C++

    class Axpby : public Primitive {
     public:
      explicit Axpby(Stream stream, float alpha, float beta)
          : Primitive(stream), alpha_(alpha), beta_(beta) {};

      /**
       * A primitive must know how to evaluate itself on the CPU/GPU
       * for the given inputs and populate the output arrays.
       *
       * To avoid unnecessary allocations, the evaluation function
       * is responsible for allocating space for the array.
       */
      void eval_cpu(
          const std::vector<array>& inputs,
          std::vector<array>& outputs) override;
      void eval_gpu(
          const std::vector<array>& inputs,
          std::vector<array>& outputs) override;

      /** The Jacobian-vector product. */
      std::vector<array> jvp(
          const std::vector<array>& primals,
          const std::vector<array>& tangents,
          const std::vector<int>& argnums) override;

      /** The vector-Jacobian product. */
      std::vector<array> vjp(
          const std::vector<array>& primals,
          const std::vector<array>& cotangents,
          const std::vector<int>& argnums,
          const std::vector<array>& outputs) override;

      /**
       * The primitive must know how to vectorize itself across
       * the given axis. The output is a pair containing the arrays
       * representing the vectorized computation and the axis which
       * corresponds to the output vectorized dimension.
       */
      virtual std::pair<std::vector<array>, std::vector<int>> vmap(
          const std::vector<array>& inputs,
          const std::vector<int>& axes) override;

     private:
      float alpha_;
      float beta_;

      /** Fall back implementation for evaluation on CPU */
      void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
    };

The :class:`Axpby` class derives from the base :class:`Primitive` class.
:class:`Axpby` treats ``alpha`` and ``beta`` as parameters. It then provides
implementations of how the output array is produced given the inputs through
:meth:`Axpby::eval_cpu` and :meth:`Axpby::eval_gpu`. It also provides rules
of transformations in :meth:`Axpby::jvp`, :meth:`Axpby::vjp`, and
:meth:`Axpby::vmap`.

Using the Primitive
^^^^^^^^^^^^^^^^^^^

Operations can use this :class:`Primitive` to add a new :class:`array` to the
computation graph. An :class:`array` can be constructed by providing its data
type, shape, the :class:`Primitive` that computes it, and the :class:`array`
inputs that are passed to the primitive.

Let's reimplement our operation now in terms of our :class:`Axpby` primitive.
A sketch of the full implementation follows.

.. code-block:: C++
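
    // A sketch of the full operation, following the pattern in the MLX sample
    // extension; the helpers used here (promote_types, astype,
    // broadcast_arrays, to_stream) are assumed from the mlx::core API.
    array axpby(
        const array& x, // Input array x
        const array& y, // Input array y
        const float alpha, // Scaling factor for x
        const float beta, // Scaling factor for y
        StreamOrDevice s /* = {} */) {
      // Promote dtypes between x and y as needed
      auto promoted_dtype = promote_types(x.dtype(), y.dtype());

      // Upcast to float32 for non-floating point inputs x and y
      auto out_dtype = issubdtype(promoted_dtype, float32)
          ? promoted_dtype
          : promote_types(promoted_dtype, float32);

      // Cast x and y up to the determined dtype (on the same stream s)
      auto x_casted = astype(x, out_dtype, s);
      auto y_casted = astype(y, out_dtype, s);

      // Broadcast the shapes of x and y (on the same stream s)
      auto broadcasted_inputs = broadcast_arrays({x_casted, y_casted}, s);
      auto out_shape = broadcasted_inputs[0].shape();

      // Construct the array as the output of the Axpby primitive
      // with the broadcasted and upcasted arrays as inputs
      return array(
          /* shape = */ out_shape,
          /* dtype = */ out_dtype,
          /* primitive = */ std::make_shared<Axpby>(to_stream(s), alpha, beta),
          /* inputs = */ broadcasted_inputs);
    }

This operation now handles dtype promotion and upcasting, shape broadcasting,
and construction of the output :class:`array` from the :class:`Axpby`
primitive and its inputs.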

Implementing the Primitive
--------------------------

No computation happens when we call the operation alone. The operation only
builds the computation graph. When we evaluate the output array, MLX schedules
the execution of the computation graph, and calls :meth:`Axpby::eval_cpu` or
:meth:`Axpby::eval_gpu` depending on the stream/device specified by the user.

.. warning::

   When :meth:`Primitive::eval_cpu` or :meth:`Primitive::eval_gpu` are called,
   no memory has been allocated for the output array. It falls on the
   implementation of these functions to allocate memory as needed.

Implementing the CPU Back-end
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Let's start by implementing a naive and generic version of
:meth:`Axpby::eval_cpu`. We declared this as a private member function of
:class:`Axpby` earlier, called :meth:`Axpby::eval`.

Our naive method will go over each element of the output array, find the
corresponding input elements of ``x`` and ``y``, and perform the operation
point-wise. This is captured in the templated function :meth:`axpby_impl`,
sketched below.

.. code-block:: C++
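
    // A sketch of the point-wise kernel, following the MLX sample extension;
    // elem_to_loc (mapping a linear index to a strided memory offset) is
    // assumed from the mlx::core utilities.
    template <typename T>
    void axpby_impl(
        const array& x,
        const array& y,
        array& out,
        float alpha_,
        float beta_) {
      // We only allocate memory when we are ready to fill the output;
      // malloc_or_wait synchronously allocates available memory
      out.set_data(allocator::malloc_or_wait(out.nbytes()));

      // Collect input and output data pointers
      const T* x_ptr = x.data<T>();
      const T* y_ptr = y.data<T>();
      T* out_ptr = out.data<T>();

      // Cast alpha and beta to the relevant types
      T alpha = static_cast<T>(alpha_);
      T beta = static_cast<T>(beta_);

      // Do the element-wise operation for each output
      for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
        // Map linear indices to offsets in x and y
        auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
        auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());

        // The output is contiguous, so a linear index suffices
        out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];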
      }
    }

Our implementation should work for all incoming floating point arrays.
Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and
``complex64``. We throw an error if we encounter an unexpected type.

.. code-block:: C++

    /** Fall back implementation for evaluation on CPU */
    void Axpby::eval(
        const std::vector<array>& inputs,
        std::vector<array>& outputs) {
      auto& x = inputs[0];
      auto& y = inputs[1];
      auto& out = outputs[0];

      // Dispatch to the correct dtype
      if (out.dtype() == float32) {
        return axpby_impl<float>(x, y, out, alpha_, beta_);
      } else if (out.dtype() == float16) {
        return axpby_impl<float16_t>(x, y, out, alpha_, beta_);
      } else if (out.dtype() == bfloat16) {
        return axpby_impl<bfloat16_t>(x, y, out, alpha_, beta_);
      } else if (out.dtype() == complex64) {
        return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
      } else {
        throw std::runtime_error(
            "[Axpby] Only supports floating point types.");
      }
    }

This is good as a fallback implementation. We can use the ``axpby`` routine
provided by the Accelerate_ framework for a faster implementation in certain
cases:

#. Accelerate does not provide implementations of ``axpby`` for half precision
   floats. We can only use it for ``float32`` types.
#. Accelerate assumes the inputs ``x`` and ``y`` are contiguous and all
   elements have fixed strides between them. We only direct to Accelerate
   if both ``x`` and ``y`` are row contiguous or column contiguous.
#. Accelerate performs the routine ``Y = (alpha * X) + (beta * Y)`` in-place.
   MLX expects to write the output to a new array. We must copy the elements
   of ``y`` into the output and use that as an input to ``axpby``.

Let's write an implementation that uses Accelerate in the right conditions.
It allocates data for the output, copies ``y`` into it, and then calls
:func:`catlas_saxpby` from Accelerate.

.. code-block:: C++

    template <typename T>
    void axpby_impl_accelerate(
        const array& x,
        const array& y,
        array& out,
        float alpha_,
        float beta_) {
      // Accelerate library provides catlas_saxpby which does
      // Y = (alpha * X) + (beta * Y) in place
      // To use it, we first copy the data in y over to the output array
      out.set_data(allocator::malloc_or_wait(out.nbytes()));

      // We then copy over the elements using the contiguous vector specialization
      copy_inplace(y, out, CopyType::Vector);
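
      // A sketch of the call, following the MLX sample extension;
      // catlas_saxpby is the in-place BLAS-like routine from Accelerate
      const T* x_ptr = x.data<T>();
      T* y_ptr = out.data<T>();

      T alpha = static_cast<T>(alpha_);
      T beta = static_cast<T>(beta_);

      // Call the in-place accelerate operator
      catlas_saxpby(
          /* N = */ out.size(),
          /* ALPHA = */ alpha,
          /* X = */ x_ptr,
          /* INCX = */ 1,
          /* BETA = */ beta,
          /* Y = */ y_ptr,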
          /* INCY = */ 1);
    }

For inputs that do not fit the criteria for Accelerate, we fall back to
:meth:`Axpby::eval`. With this in mind, let's finish our
:meth:`Axpby::eval_cpu`.

.. code-block:: C++

    /** Evaluate primitive on CPU using accelerate specializations */
    void Axpby::eval_cpu(
        const std::vector<array>& inputs,
        std::vector<array>& outputs) {
      assert(inputs.size() == 2);
      auto& x = inputs[0];
      auto& y = inputs[1];
      auto& out = outputs[0];

      // Accelerate specialization for contiguous single precision float arrays
      if (out.dtype() == float32 &&
          ((x.flags().row_contiguous && y.flags().row_contiguous) ||
           (x.flags().col_contiguous && y.flags().col_contiguous))) {
        axpby_impl_accelerate<float>(x, y, out, alpha_, beta_);
        return;
      }

      // Fall back to common back-end if specializations are not available
      eval(inputs, outputs);
    }

Just this much is enough to run the operation :meth:`axpby` on a CPU stream! If
you do not plan on running the operation on the GPU or using transforms on
computation graphs that contain :class:`Axpby`, you can stop implementing the
primitive here and enjoy the speed-ups you get from the Accelerate library.

Implementing the GPU Back-end
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Apple silicon devices address their GPUs using the Metal_ shading language, and
GPU kernels in MLX are written using Metal.

.. note::

   Here are some helpful resources if you are new to Metal:

   * A walkthrough of the metal compute pipeline: `Metal Example`_
   * Documentation for metal shading language: `Metal Specification`_
   * Using metal from C++: `Metal-cpp`_

Let's keep the GPU kernel simple. We will launch exactly as many threads as
there are elements in the output. Each thread will pick the element it needs
from ``x`` and ``y``, do the point-wise operation, and update its assigned
element in the output.

.. code-block:: C++
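
    // A sketch of the kernel signature, following the MLX sample extension;
    // the buffer layout is an assumption that matches the encoder calls in
    // eval_gpu below.
    template <typename T>
    [[kernel]] void axpby_general(
        device const T* x [[buffer(0)]],
        device const T* y [[buffer(1)]],
        device T* out [[buffer(2)]],
        constant const float& alpha [[buffer(3)]],
        constant const float& beta [[buffer(4)]],
        constant const int* shape [[buffer(5)]],
        constant const size_t* x_strides [[buffer(6)]],
        constant const size_t* y_strides [[buffer(7)]],
        constant const int& ndim [[buffer(8)]],
        uint index [[thread_position_in_grid]]) {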
      // Convert linear indices to offsets in array
      auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
      auto y_offset = elem_to_loc(index, shape, y_strides, ndim);

      // Do the operation and update the output
      out[index] =
          static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset];
    }

We then need to instantiate this template for all floating point types and
give each instantiation a unique host name so we can identify it.

.. code-block:: C++
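
    // A sketch of the instantiation macro; the exact definition in the MLX
    // sample extension binds a unique [[host_name]] to each instantiation of
    // axpby_general so it can be looked up at runtime.
    #define instantiate_axpby(type_name, type)            \
      template [[host_name("axpby_general_" #type_name)]] \
      [[kernel]] void axpby_general<type>(                \
          device const type* x [[buffer(0)]],             \
          device const type* y [[buffer(1)]],             \
          device type* out [[buffer(2)]],                 \
          constant const float& alpha [[buffer(3)]],      \
          constant const float& beta [[buffer(4)]],       \
          constant const int* shape [[buffer(5)]],        \
          constant const size_t* x_strides [[buffer(6)]], \
          constant const size_t* y_strides [[buffer(7)]], \
          constant const int& ndim [[buffer(8)]],         \
          uint index [[thread_position_in_grid]]);

    instantiate_axpby(float32, float);
    instantiate_axpby(float16, half);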
    instantiate_axpby(bfloat16, bfloat16_t);
    instantiate_axpby(complex64, complex64_t);

The logic to determine the kernel, set the inputs, resolve the grid
dimensions, and dispatch to the GPU is contained in :meth:`Axpby::eval_gpu`,
as shown below.

.. code-block:: C++

    /** Evaluate primitive on GPU */
    void Axpby::eval_gpu(
        const std::vector<array>& inputs,
        std::vector<array>& outputs) {
      // Prepare inputs
      assert(inputs.size() == 2);
      auto& x = inputs[0];
      auto& y = inputs[1];
      auto& out = outputs[0];

      // Each primitive carries the stream it should execute on
      // and each stream carries its device identifiers
      auto& s = stream();
      // We get the needed metal device using the stream
      auto& d = metal::device(s.device);

      // Allocate output memory
      out.set_data(allocator::malloc_or_wait(out.nbytes()));

      // Resolve name of kernel
      std::ostringstream kname;
      kname << "axpby_" << "general_" << type_to_name(out);
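
      // A sketch of the elided setup, following the MLX sample extension.
      // The helper names (register_library, get_kernel, get_command_encoder,
      // set_array_buffer) are assumptions based on the metal::Device API of
      // this release and may differ across versions.
      d.register_library("mlx_ext");
      auto kernel = d.get_kernel(kname.str(), "mlx_ext");

      // Prepare to encode kernel
      auto& compute_encoder = d.get_command_encoder(s.index);
      compute_encoder->setComputePipelineState(kernel);

      int ndim = out.ndim();
      size_t nelem = out.size();

      // Encode input and output arrays
      set_array_buffer(compute_encoder, x, 0);
      set_array_buffer(compute_encoder, y, 1);
      set_array_buffer(compute_encoder, out, 2);

      // Encode alpha and beta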
      compute_encoder->setBytes(&alpha_, sizeof(float), 3);
      compute_encoder->setBytes(&beta_, sizeof(float), 4);

      // Encode shape, strides and ndim
      compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
      compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
      compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
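      compute_encoder->setBytes(&ndim, sizeof(int), 8);

      // A sketch of the dispatch: launch one thread per output element,
      // capped by the kernel's maximum threadgroup size (names assumed as
      // above).
      size_t tgp_size = std::min(nelem, kernel->maxTotalThreadsPerThreadgroup());
      MTL::Size group_dims = MTL::Size(tgp_size, 1, 1);
      MTL::Size grid_dims = MTL::Size(nelem, 1, 1);

      // Launch the grid with the given number of threads divided among
      // the given threadgroups
      compute_encoder->dispatchThreads(grid_dims, group_dims);
    }
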
We can now call the :meth:`axpby` operation on both the CPU and the GPU!

A few things to note about MLX and Metal before moving on. MLX keeps track of
the active compute command encoder and the ``MTLCommandBuffer`` to which it is
associated. We rely on :meth:`d.get_command_encoder` to give us the active
metal compute command encoder instead of building a new one and calling
:meth:`compute_encoder->end_encoding` at the end. MLX adds kernels (compute
pipelines) to the active command buffer until some specified limit is hit or
the command buffer needs to be flushed for synchronization.

Primitive Transforms
^^^^^^^^^^^^^^^^^^^^

Next, let's add implementations for transformations in a :class:`Primitive`.
These transformations can be built on top of other operations, including the
one we just defined:

.. code-block:: C++

    /** The Jacobian-vector product. */
    std::vector<array> Axpby::jvp(
        const std::vector<array>& primals,
        const std::vector<array>& tangents,
        const std::vector<int>& argnums) {
      // Forward mode diff that pushes along the tangents
      // The jvp transform on the primitive can be built with ops
      // that are scheduled on the same stream as the primitive

      // If argnums = {0}, we only push along x in which case the
      // jvp is just the tangent scaled by alpha
      // Similarly, if argnums = {1}, the jvp is just the tangent
      // scaled by beta
      if (argnums.size() == 1) {
        auto scale = argnums[0] == 0 ? alpha_ : beta_;
        auto scale_arr = array(scale, tangents[0].dtype());
        return {multiply(scale_arr, tangents[0], stream())};
      }
      // If argnums = {0, 1}, we take contributions from both
      // which gives us jvp = tangent_x * alpha + tangent_y * beta
      else {
        return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
      }
    }

.. code-block:: C++

    /** The vector-Jacobian product. */
    std::vector<array> Axpby::vjp(
        const std::vector<array>& primals,
        const std::vector<array>& cotangents,
        const std::vector<int>& argnums,
        const std::vector<array>& /* unused */) {
      // Reverse mode diff
      std::vector<array> vjps;
      for (auto arg : argnums) {
        auto scale = arg == 0 ? alpha_ : beta_;
        auto scale_arr = array(scale, cotangents[0].dtype());
        vjps.push_back(multiply(scale_arr, cotangents[0], stream()));
      }
      return vjps;
    }

Note that a transformation does not need to be fully defined to start using
the :class:`Primitive`.

.. code-block:: C++

    /** Vectorize primitive along given axis */
    std::pair<std::vector<array>, std::vector<int>> Axpby::vmap(
        const std::vector<array>& inputs,
        const std::vector<int>& axes) {
      throw std::runtime_error("[Axpby] vmap not implemented.");
    }

Building and Binding
--------------------

Let's look at the overall directory structure first.

| extensions
| ├── axpby
| │   ├── axpby.cpp
| │   ├── axpby.h
| │   └── axpby.metal
| ├── mlx_sample_extensions
| │   └── __init__.py
| ├── bindings.cpp
| ├── CMakeLists.txt
| └── setup.py

* ``extensions/axpby/`` defines the C++ extension library
* ``extensions/mlx_sample_extensions`` sets out the structure for the
  associated Python package
* ``extensions/bindings.cpp`` provides Python bindings for our operation
* ``extensions/CMakeLists.txt`` holds CMake rules to build the library and
  Python bindings
* ``extensions/setup.py`` holds the ``setuptools`` rules to build and install
  the Python package

Binding to Python
^^^^^^^^^^^^^^^^^

We use nanobind_ to build a Python API for the C++ library. Since bindings for
components such as :class:`mlx.core.array`, :class:`mlx.core.stream`, etc. are
already provided, adding our :meth:`axpby` is simple.

.. code-block:: C++

    NB_MODULE(_ext, m) {
      m.doc() = "Sample extension for MLX";

      m.def(
          "axpby",
          &axpby,
          "x"_a,
          "y"_a,
          "alpha"_a,
          "beta"_a,
          nb::kw_only(),
          "stream"_a = nb::none(),
          R"(
            Scale and sum two vectors element-wise
            ``z = alpha * x + beta * y``

            Follows numpy style broadcasting between ``x`` and ``y``
            Inputs are upcasted to floats if needed

            Args:
                x (array): Input array.
                y (array): Input array.
                alpha (float): Scaling factor for ``x``.
                beta (float): Scaling factor for ``y``.

            Returns:
                array: ``alpha * x + beta * y``
          )");
    }

Most of the complexity in the above example comes from additional bells and
whistles such as the literal names and doc-strings.

.. warning::

   :mod:`mlx.core` must be imported before importing
   :mod:`mlx_sample_extensions` as defined by the nanobind module above to
   ensure that the casters for :mod:`mlx.core` components like
   :class:`mlx.core.array` are available.

.. _Building with CMake:

Building with CMake
^^^^^^^^^^^^^^^^^^^

Building the C++ extension library only requires that you ``find_package(MLX
CONFIG)`` and then link it to your library.

.. code-block:: cmake

    # Find MLX (a sketch following the sample CMakeLists.txt)
    find_package(MLX CONFIG REQUIRED)

    # Add library
    add_library(mlx_ext)

    # Add sources
    target_sources(
        mlx_ext
        PUBLIC
        ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp
    )

    # Add include headers
    target_include_directories(mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR})

    # Link to mlx
    target_link_libraries(mlx_ext PUBLIC mlx)

We also need to build the attached Metal library. For convenience, we provide
a :meth:`mlx_build_metallib` function that builds a ``.metallib`` target given
sources, headers, destinations, etc. (defined in ``cmake/extension.cmake`` and
automatically imported with the MLX package).

Here is what that looks like in practice:

.. code-block:: cmake

    # Build metallib (a sketch following the sample CMakeLists.txt)
    if(MLX_BUILD_METAL)
      mlx_build_metallib(
        TARGET mlx_ext_metallib
        TITLE mlx_ext
        SOURCES ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
        INCLUDE_DIRS ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS}
        OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}
      )

      add_dependencies(
        mlx_ext
        mlx_ext_metallib
      )
    endif()

Finally, we build the nanobind_ bindings:

.. code-block:: cmake

    nanobind_add_module(
      _ext
      NB_STATIC STABLE_ABI LTO NOMINSIZE
      NB_DOMAIN mlx
      ${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
    )
    target_link_libraries(_ext PRIVATE mlx_ext)

    if(BUILD_SHARED_LIBS)
      target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path)
    endif()

Building with ``setuptools``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Once we have set out the CMake build rules as described above, we can use the
build utilities defined in :mod:`mlx.extension`:

.. code-block:: python

    from mlx import extension
    from setuptools import setup

    if __name__ == "__main__":
        setup(
            name="mlx_sample_extensions",
            version="0.0.0",
            description="Sample C++ and Metal extensions for MLX primitives.",
            ext_modules=[extension.CMakeExtension("mlx_sample_extensions._ext")],
            cmdclass={"build_ext": extension.CMakeBuild},
            packages=["mlx_sample_extensions"],
            package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
            extras_require={"dev": []},
            zip_safe=False,
            python_requires=">=3.8",
        )

.. note::

   We treat ``extensions/mlx_sample_extensions`` as the package directory
   even though it only contains a ``__init__.py`` to ensure the following:

   * :mod:`mlx.core` must be imported before importing :mod:`_ext`
   * The C++ extension library and the metal library are co-located with the
     Python bindings and copied together if the package is installed

To build the package, first install the build dependencies with ``pip install
-r requirements.txt``. You can then build inplace for development using
``python setup.py build_ext -j8 --inplace`` (in ``extensions/``).

This results in the directory structure:

| extensions
| ├── mlx_sample_extensions
| │   ├── __init__.py
| │   ├── libmlx_ext.dylib # C++ extension library
| │   ├── mlx_ext.metallib # Metal library
| │   └── _ext.cpython-3x-darwin.so # Python Binding
| ...

When you try to install using the command ``python -m pip install .`` (in
``extensions/``), the package will be installed with the same structure as
``extensions/mlx_sample_extensions`` and the C++ and Metal library will be
copied along with the Python binding since they are specified as
``package_data``.

Usage
-----

After installing the extension as described above, you should be able to
simply import the Python package and play with it as you would any other MLX
operation. Let's look at a simple script and its results:

.. code-block:: python
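
    # A minimal sketch of the example script, assuming the package was
    # installed as described above.
    import mlx.core as mx
    from mlx_sample_extensions import axpby

    a = mx.ones((3, 4))
    b = mx.ones((3, 4))
    c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)

    # Every element of c should be alpha * 1 + beta * 1 = 6
    print(f"c correctness: {mx.all(c == 6.0).item()}")

Output:

.. code-block::
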
    c correctness: True

Results
^^^^^^^

Let's run a quick benchmark and see how our new ``axpby`` operation compares
with the naive :meth:`simple_axpby` we first defined on the CPU.

.. code-block:: python

    import mlx.core as mx
    from mlx_sample_extensions import axpby
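    import time

    # A sketch of the elided setup; the baseline definition, problem size,
    # and random inputs are assumptions for illustration.
    mx.set_default_device(mx.cpu)

    def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
        return alpha * x + beta * y

    M = 4096
    N = 4096

    x = mx.random.normal((M, N))
    y = mx.random.normal((M, N))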
    alpha = 4.0
    beta = 2.0

    mx.eval(x, y)

    def bench(f):
        # Warm up
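        # A sketch of the elided benchmark body; iteration counts are
        # assumptions for illustration.
        for i in range(100):
            z = f(x, y, alpha, beta)
            mx.eval(z)

        # Timed run
        s = time.time()
        for i in range(5000):
            z = f(x, y, alpha, beta)
        mx.eval(z)
        e = time.time()
        return e - s

    simple_time = bench(simple_axpby)
    custom_time = bench(axpby)
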
    print(f"Simple axpby: {simple_time:.3f} s | Custom axpby: {custom_time:.3f} s")

The results are ``Simple axpby: 0.114 s | Custom axpby: 0.109 s``. We see
modest improvements right away!

This operation can now be used to build other operations, in
:class:`mlx.nn.Module` calls, and also as a part of graph transformations like
:meth:`grad`.

Scripts
-------

.. admonition:: Download the code

   The full example code is available in `mlx <https://github.com/ml-explore/mlx/tree/main/examples/extensions/>`_.

.. _Accelerate: https://developer.apple.com/documentation/accelerate/blas?language=objc
.. _Metal: https://developer.apple.com/documentation/metal?language=objc
.. _Metal-cpp: https://developer.apple.com/metal/cpp/
.. _`Metal Specification`: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
.. _`Metal Example`: https://developer.apple.com/documentation/metal/performing_calculations_on_a_gpu?language=objc
.. _nanobind: https://nanobind.readthedocs.io/en/latest/

``extensions/CMakeLists.txt``:

cmake_minimum_required(VERSION 3.27)
project(_ext LANGUAGES CXX)

# ----------------------------- Setup -----------------------------
set(CMAKE_CXX_STANDARD 17)

option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON)

# ----------------------------- Dependencies -----------------------------
find_package(MLX CONFIG REQUIRED)
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
execute_process(
  COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
  OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR)
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
find_package(nanobind CONFIG REQUIRED)

# ----------------------------- Extensions -----------------------------

# (library target mlx_ext defined as before)
target_link_libraries(mlx_ext PUBLIC mlx)

# Build metallib
if(MLX_BUILD_METAL)
  mlx_build_metallib(
    TARGET mlx_ext_metallib
    TITLE mlx_ext
    SOURCES ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
    INCLUDE_DIRS ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS}
    OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}
  )

  add_dependencies(
    mlx_ext
    mlx_ext_metallib
  )
endif()

# ----------------------------- Python Bindings -----------------------------
nanobind_add_module(
  _ext
  NB_STATIC STABLE_ABI LTO NOMINSIZE
  NB_DOMAIN mlx
  ${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
)
target_link_libraries(_ext PRIVATE mlx_ext)

if(BUILD_SHARED_LIBS)
  target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path)
endif()

A new README with build instructions is also added:
## Build the extensions
```
pip install -e .
```
For faster builds during development, you can also pre-install the requirements:
```
pip install -r requirements.txt
```
And then run:
```
python setup.py build_ext -j8 --inplace
```

``extensions/axpby/axpby.cpp``:
// Copyright © 2023-2024 Apple Inc.

#include <cassert>
#include <iostream>

// ... (inside array axpby(...))
  auto promoted_dtype = promote_types(x.dtype(), y.dtype());

  // Upcast to float32 for non-floating point inputs x and y
  auto out_dtype = issubdtype(promoted_dtype, float32)
      ? promoted_dtype
      : promote_types(promoted_dtype, float32);

// ... (further down in the file)
/** Fall back implementation for evaluation on CPU */
void Axpby::eval(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  // Check the inputs (registered in the op while constructing the out array)
  assert(inputs.size() == 2);
  auto& x = inputs[0];
  auto& y = inputs[1];
  auto& out = outputs[0];

  // Dispatch to the correct dtype
  if (out.dtype() == float32) {
    // ...

// ... (inside void axpby_impl_accelerate(...))
  // Allocate memory for the output
  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  // We then copy over the elements using the contiguous vector specialization
  copy_inplace(y, out, CopyType::Vector);

// ... (further down in the file)
/** Evaluate primitive on CPU using accelerate specializations */
void Axpby::eval_cpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() == 2);
  auto& x = inputs[0];
  auto& y = inputs[1];
  auto& out = outputs[0];

  // Accelerate specialization for contiguous single precision float arrays
  if (out.dtype() == float32 &&
      ((x.flags().row_contiguous && y.flags().row_contiguous) ||
       (x.flags().col_contiguous && y.flags().col_contiguous))) {
    axpby_impl_accelerate<float>(x, y, out, alpha_, beta_);
    return;
  }

  // Fall back to common backend if specializations are not available
  eval(inputs, outputs);
}

#else // Accelerate not available

/** Evaluate primitive on CPU falling back to common backend */
void Axpby::eval_cpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  eval(inputs, outputs);
}

#endif

// ... (further down in the file)
/** Evaluate primitive on GPU */
void Axpby::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  // Prepare inputs
  assert(inputs.size() == 2);
  auto& x = inputs[0];
  auto& y = inputs[1];
  auto& out = outputs[0];

  // Each primitive carries the stream it should execute on
  // and each stream carries its device identifiers

// ... (inside bool Axpby::is_equivalent(const Primitive& other) const)
  return alpha_ == r_other.alpha_ && beta_ == r_other.beta_;
}

} // namespace mlx::core

``extensions/axpby/axpby.h``:

// ... (inside class Axpby : public Primitive)
   * To avoid unnecessary allocations, the evaluation function
   * is responsible for allocating space for the array.
   */
  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;
  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;

  /** The Jacobian-vector product. */

  // ... (further down in the class)
  float beta_;

  /** Fall back implementation for evaluation on CPU */
  void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
};

} // namespace mlx::core

``extensions/bindings.cpp``:
// Copyright © 2023-2024 Apple Inc.

#include <nanobind/nanobind.h>
#include <nanobind/stl/variant.h>
#include "axpby/axpby.h" #include "axpby/axpby.h"

namespace nb = nanobind;
using namespace nb::literals;
using namespace mlx::core;

NB_MODULE(_ext, m) {
  m.doc() = "Sample extension for MLX";

  m.def(
      "axpby",
      &axpby,
      "x"_a,
      "y"_a,
      "alpha"_a,
      "beta"_a,
      nb::kw_only(),
      "stream"_a = nb::none(),
R"pbdoc( R"(
        Scale and sum two vectors element-wise
        ``z = alpha * x + beta * y``

        Follows numpy style broadcasting between ``x`` and ``y``
        Inputs are upcasted to floats if needed

        Returns:
            array: ``alpha * x + beta * y``
      )");
}

``extensions/pyproject.toml``:
[build-system]
requires = [
    "setuptools>=42",
    "cmake>=3.24",
    "mlx>=0.9.0",
    "nanobind@git+https://github.com/wjakob/nanobind.git#egg=4148debcf91f5ccab0c3b8d67b5c3cabd61f407f",
]
build-backend = "setuptools.build_meta"

``extensions/requirements.txt`` (a new file):
setuptools>=42
cmake>=3.24
mlx>=0.9.0
nanobind@git+https://github.com/wjakob/nanobind.git#egg=4148debcf91f5ccab0c3b8d67b5c3cabd61f407f

``extensions/setup.py``:
# Copyright © 2023-2024 Apple Inc.

from setuptools import setup
from mlx import extension

if __name__ == "__main__":
    setup(
        name="mlx_sample_extensions",
        version="0.0.0",
        description="Sample C++ and Metal extensions for MLX primitives.",
        ext_modules=[extension.CMakeExtension("mlx_sample_extensions._ext")],
        cmdclass={"build_ext": extension.CMakeBuild},
        packages=["mlx_sample_extensions"],
        package_dir={"": "."},
        package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
        extras_require={"dev": []},
        zip_safe=False,
        python_requires=">=3.8",
    )