Mirror of https://github.com/ml-explore/mlx.git
Extensions (#962)

* start to fix extensions
* mostly fixed extensions
* fix extension build
* couple more nits

parent 42afe27e12
commit b63ef10a7f
@@ -1,24 +1,16 @@

Developer Documentation
=======================

You can extend MLX with custom operations on the CPU or GPU. This guide
explains how to do that with a simple example.

Introducing the Example
-----------------------

Let's say you would like an operation that takes in two arrays, ``x`` and
``y``, scales them both by coefficients ``alpha`` and ``beta`` respectively,
and then adds them together to get the result ``z = alpha * x + beta * y``.
You can do that in MLX directly:

.. code-block:: python

@@ -27,44 +19,35 @@ writing out a function as follows:

   def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
       return alpha * x + beta * y

This function performs that operation while leaving the implementation and
function transformations to MLX.
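
For instance, because ``simple_axpby`` is built from MLX operations,
transformations such as :meth:`grad` already work on it. The short sketch below
is illustrative and not part of the original example:

.. code-block:: python

   import mlx.core as mx

   def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
       return alpha * x + beta * y

   x = mx.array([1.0, 2.0, 3.0])
   y = mx.array([4.0, 5.0, 6.0])

   # d/dx sum(alpha * x + beta * y) = alpha, so the gradient is all 2s
   grad_fn = mx.grad(lambda x: simple_axpby(x, y, 2.0, 3.0).sum())
   print(grad_fn(x))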

However, you may need to customize the underlying implementation, perhaps to
make it faster or for custom differentiation. In this tutorial we will go
through adding custom extensions. It will cover:

* The structure of the MLX library.
* Implementing a CPU operation that redirects to Accelerate_ when appropriate.
* Implementing a GPU operation using Metal.
* Adding the ``vjp`` and ``jvp`` function transformations.
* Building a custom extension and binding it to Python.

Operations and Primitives
-------------------------

Operations in MLX build the computation graph. Primitives provide the rules for
evaluating and transforming the graph. Let's start by discussing operations in
more detail.
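
This split is easy to observe from Python, since MLX is lazy: calling an
operation only records a node in the graph, and the primitive's evaluation
rules run when the result is needed. A small illustrative sketch:

.. code-block:: python

   import mlx.core as mx

   z = mx.add(mx.array(1.0), mx.array(2.0))  # builds the graph, nothing runs yet
   mx.eval(z)                                # the Add primitive is evaluated here
   print(z)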

Operations
^^^^^^^^^^

Operations are the front-end functions that operate on arrays. They are defined
in the C++ API (:ref:`cpp_ops`), and the Python API (:ref:`ops`) binds them.

We would like an operation, :meth:`axpby`, that takes in two arrays ``x`` and
``y``, and two scalars, ``alpha`` and ``beta``. This is how to define it in
C++:

.. code-block:: C++

@@ -83,10 +66,7 @@ C++ API:

       StreamOrDevice s = {} // Stream on which to schedule the operation
   );

The simplest way to implement this operation is in terms of existing
operations:

.. code-block:: C++

@@ -100,25 +80,23 @@ of existing operations.

     // Scale x and y on the provided stream
     auto ax = multiply(array(alpha), x, s);
     auto by = multiply(array(beta), y, s);

     // Add and return
     return add(ax, by, s);
   }

The operations themselves do not contain the implementations that act on the
data, nor do they contain the rules of transformations. Rather, they are an
easy-to-use interface that uses :class:`Primitive` building blocks.

Primitives
^^^^^^^^^^

A :class:`Primitive` is part of the computation graph of an :class:`array`. It
defines how to create output arrays given input arrays. Further, a
:class:`Primitive` has methods to run on the CPU or GPU and for function
transformations such as ``vjp`` and ``jvp``. Let's go back to our example to be
more concrete:

.. code-block:: C++

@@ -134,11 +112,15 @@ back and go to our example to give ourselves a more concrete image.

      * To avoid unnecessary allocations, the evaluation function
      * is responsible for allocating space for the array.
      */
     void eval_cpu(
         const std::vector<array>& inputs,
         std::vector<array>& outputs) override;
     void eval_gpu(
         const std::vector<array>& inputs,
         std::vector<array>& outputs) override;

     /** The Jacobian-vector product. */
     std::vector<array> jvp(
         const std::vector<array>& primals,
         const std::vector<array>& tangents,
         const std::vector<int>& argnums) override;

@@ -147,7 +129,8 @@

     std::vector<array> vjp(
         const std::vector<array>& primals,
         const std::vector<array>& cotangents,
         const std::vector<int>& argnums,
         const std::vector<array>& outputs) override;

     /**
      * The primitive must know how to vectorize itself across

@@ -155,7 +138,7 @@

      * representing the vectorized computation and the axis which
      * corresponds to the output vectorized dimension.
      */
     virtual std::pair<std::vector<array>, std::vector<int>> vmap(
         const std::vector<array>& inputs,
         const std::vector<int>& axes) override;

@@ -175,22 +158,22 @@

     void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
   };

The :class:`Axpby` class derives from the base :class:`Primitive` class and
treats ``alpha`` and ``beta`` as parameters. It then provides implementations
of how the output array is produced given the inputs through
:meth:`Axpby::eval_cpu` and :meth:`Axpby::eval_gpu`. It also provides rules of
transformations in :meth:`Axpby::jvp`, :meth:`Axpby::vjp`, and
:meth:`Axpby::vmap`.

Using the Primitive
^^^^^^^^^^^^^^^^^^^

Operations can use this :class:`Primitive` to add a new :class:`array` to the
computation graph. An :class:`array` can be constructed by providing its data
type, shape, the :class:`Primitive` that computes it, and the :class:`array`
inputs that are passed to the primitive.

Let's reimplement our operation now in terms of our :class:`Axpby` primitive.

.. code-block:: C++

@@ -238,27 +221,26 @@ This operation now handles the following:

Implementing the Primitive
--------------------------

No computation happens when we call the operation alone. The operation only
builds the computation graph. When we evaluate the output array, MLX schedules
the execution of the computation graph, and calls :meth:`Axpby::eval_cpu` or
:meth:`Axpby::eval_gpu` depending on the stream/device specified by the user.

.. warning::
   When :meth:`Primitive::eval_cpu` or :meth:`Primitive::eval_gpu` are called,
   no memory has been allocated for the output array. It falls on the
   implementation of these functions to allocate memory as needed.

Implementing the CPU Back-end
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Let's start by implementing a naive and generic version of
:meth:`Axpby::eval_cpu`. We declared this as a private member function of
:class:`Axpby` earlier called :meth:`Axpby::eval`.

Our naive method will go over each element of the output array, find the
corresponding input elements of ``x`` and ``y`` and perform the operation
point-wise. This is captured in the templated function :meth:`axpby_impl`.

.. code-block:: C++

@@ -296,19 +278,19 @@ pointwise. This is captured in the templated function :meth:`axpby_impl`.

       }
   }

Our implementation should work for all incoming floating point arrays.
Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and
``complex64``. We throw an error if we encounter an unexpected type.

.. code-block:: C++

   /** Fall back implementation for evaluation on CPU */
   void Axpby::eval(
       const std::vector<array>& inputs,
       std::vector<array>& outputs) {
     auto& x = inputs[0];
     auto& y = inputs[1];
     auto& out = outputs[0];

     // Dispatch to the correct dtype
     if (out.dtype() == float32) {
@@ -321,28 +303,26 @@ if we encounter an unexpected type.

       return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
     } else {
       throw std::runtime_error(
           "[Axpby] Only supports floating point types.");
     }
   }

This is good as a fallback implementation. We can use the ``axpby`` routine
provided by the Accelerate_ framework for a faster implementation in certain
cases. There are a few things to keep in mind:

#. Accelerate does not provide implementations of ``axpby`` for half precision
   floats. We can only use it for ``float32`` types.
#. Accelerate assumes the inputs ``x`` and ``y`` are contiguous and all
   elements have fixed strides between them. We only direct to Accelerate
   if both ``x`` and ``y`` are row contiguous or column contiguous.
#. Accelerate performs the routine ``Y = (alpha * X) + (beta * Y)`` in-place.
   MLX expects to write the output to a new array. We must copy the elements
   of ``y`` into the output and use that as an input to ``axpby``.

Let's write an implementation that uses Accelerate in the right conditions.
It allocates data for the output, copies ``y`` into it, and then calls
:func:`catlas_saxpby` from Accelerate.

.. code-block:: C++

@@ -356,17 +336,7 @@ and then call the :meth:`catlas_saxpby` from accelerate.

     // Accelerate library provides catlas_saxpby which does
     // Y = (alpha * X) + (beta * Y) in place
     // To use it, we first copy the data in y over to the output array
     out.set_data(allocator::malloc_or_wait(out.nbytes()));

     // We then copy over the elements using the contiguous vector specialization
     copy_inplace(y, out, CopyType::Vector);

@@ -389,18 +359,20 @@ and then call the :meth:`catlas_saxpby` from accelerate.

         /* INCY = */ 1);
   }

For inputs that do not fit the criteria for Accelerate, we fall back to
:meth:`Axpby::eval`. With this in mind, let's finish our
:meth:`Axpby::eval_cpu`.

.. code-block:: C++

   /** Evaluate primitive on CPU using accelerate specializations */
   void Axpby::eval_cpu(
       const std::vector<array>& inputs,
       std::vector<array>& outputs) {
     assert(inputs.size() == 2);
     auto& x = inputs[0];
     auto& y = inputs[1];
     auto& out = outputs[0];

     // Accelerate specialization for contiguous single precision float arrays
     if (out.dtype() == float32 &&

@@ -410,35 +382,33 @@ With this in mind, lets finally implement our :meth:`Axpby::eval_cpu`.

       return;
     }

     // Fall back to common back-end if specializations are not available
     eval(inputs, outputs);
   }

Just this much is enough to run the operation :meth:`axpby` on a CPU stream! If
you do not plan on running the operation on the GPU or using transforms on
computation graphs that contain :class:`Axpby`, you can stop implementing the
primitive here and enjoy the speed-ups you get from the Accelerate library.

Implementing the GPU Back-end
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Apple silicon devices address their GPUs using the Metal_ shading language, and
GPU kernels in MLX are written using Metal.

.. note::

   Here are some helpful resources if you are new to Metal:

   * A walkthrough of the Metal compute pipeline: `Metal Example`_
   * Documentation for the Metal shading language: `Metal Specification`_
   * Using Metal from C++: `Metal-cpp`_

Let's keep the GPU kernel simple. We will launch exactly as many threads as
there are elements in the output. Each thread will pick the element it needs
from ``x`` and ``y``, do the point-wise operation, and update its assigned
element in the output.

.. code-block:: C++

@@ -457,15 +427,14 @@ element in the output.

     // Convert linear indices to offsets in array
     auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
     auto y_offset = elem_to_loc(index, shape, y_strides, ndim);

     // Do the operation and update the output
     out[index] =
         static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset];
   }

We then need to instantiate this template for all floating point types and give
each instantiation a unique host name so we can identify it.

.. code-block:: C++

@@ -488,29 +457,21 @@ each data type.

   instantiate_axpby(bfloat16, bfloat16_t);
   instantiate_axpby(complex64, complex64_t);

The logic to determine the kernel, set the inputs, resolve the grid dimensions,
and dispatch to the GPU is contained in :meth:`Axpby::eval_gpu` as shown
below.

.. code-block:: C++

   /** Evaluate primitive on GPU */
   void Axpby::eval_gpu(
       const std::vector<array>& inputs,
       std::vector<array>& outputs) {
     // Prepare inputs
     assert(inputs.size() == 2);
     auto& x = inputs[0];
     auto& y = inputs[1];
     auto& out = outputs[0];

     // Each primitive carries the stream it should execute on
     // and each stream carries its device identifiers

@@ -518,10 +479,10 @@ below.

     // We get the needed metal device using the stream
     auto& d = metal::device(s.device);

     // Allocate output memory
     out.set_data(allocator::malloc_or_wait(out.nbytes()));

     // Resolve name of kernel
     std::ostringstream kname;
     kname << "axpby_" << "general_" << type_to_name(out);

@@ -552,7 +513,7 @@ below.

     compute_encoder->setBytes(&alpha_, sizeof(float), 3);
     compute_encoder->setBytes(&beta_, sizeof(float), 4);

     // Encode shape, strides and ndim
     compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
     compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
     compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
@@ -575,28 +536,25 @@ below.

We can now call the :meth:`axpby` operation on both the CPU and the GPU!

A few things to note about MLX and Metal before moving on. MLX keeps track of
the active ``command_buffer`` and the ``MTLCommandBuffer`` to which it is
associated. We rely on :meth:`d.get_command_encoder` to give us the active
metal compute command encoder instead of building a new one and calling
:meth:`compute_encoder->end_encoding` at the end. MLX adds kernels (compute
pipelines) to the active command buffer until some specified limit is hit or
the command buffer needs to be flushed for synchronization.

Primitive Transforms
^^^^^^^^^^^^^^^^^^^^

Next, let's add implementations for transformations in a :class:`Primitive`.
These transformations can be built on top of other operations, including the
one we just defined:

.. code-block:: C++

   /** The Jacobian-vector product. */
   std::vector<array> Axpby::jvp(
       const std::vector<array>& primals,
       const std::vector<array>& tangents,
       const std::vector<int>& argnums) {
@@ -611,12 +569,12 @@ us the following :meth:`Axpby::jvp` and :meth:`Axpby::vjp` implementations.

     if (argnums.size() == 1) {
       auto scale = argnums[0] == 0 ? alpha_ : beta_;
       auto scale_arr = array(scale, tangents[0].dtype());
       return {multiply(scale_arr, tangents[0], stream())};
     }
     // If argnums = {0, 1}, we take contributions from both
     // which gives us jvp = tangent_x * alpha + tangent_y * beta
     else {
       return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
     }
   }
@@ -625,34 +583,35 @@ us the following :meth:`Axpby::jvp` and :meth:`Axpby::vjp` implementations.

   /** The vector-Jacobian product. */
   std::vector<array> Axpby::vjp(
       const std::vector<array>& primals,
       const std::vector<array>& cotangents,
       const std::vector<int>& argnums,
       const std::vector<array>& /* unused */) {
     // Reverse mode diff
     std::vector<array> vjps;
     for (auto arg : argnums) {
       auto scale = arg == 0 ? alpha_ : beta_;
       auto scale_arr = array(scale, cotangents[0].dtype());
       vjps.push_back(multiply(scale_arr, cotangents[0], stream()));
     }
     return vjps;
   }

Note, a transformation does not need to be fully defined to start using
the :class:`Primitive`.

.. code-block:: C++

   /** Vectorize primitive along given axis */
   std::pair<std::vector<array>, std::vector<int>> Axpby::vmap(
       const std::vector<array>& inputs,
       const std::vector<int>& axes) {
     throw std::runtime_error("[Axpby] vmap not implemented.");
   }
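
With :meth:`Axpby::vjp` in place, gradient transformations already work through
the custom primitive, even though ``vmap`` is left unimplemented. The snippet
below is an illustrative sketch that assumes the Python binding built in the
next section:

.. code-block:: python

   import mlx.core as mx
   from mlx_sample_extensions import axpby

   x = mx.ones((4,))
   y = mx.ones((4,))

   # d/dx sum(alpha * x + beta * y) = alpha, so this prints an array of 4s
   grad_fn = mx.grad(lambda x: axpby(x, y, 4.0, 2.0).sum())
   print(grad_fn(x))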

Building and Binding
--------------------

Let's look at the overall directory structure first.

| extensions
| ├── axpby

@@ -666,40 +625,39 @@ Let's look at the overall directory structure first.

| └── setup.py

* ``extensions/axpby/`` defines the C++ extension library
* ``extensions/mlx_sample_extensions`` sets out the structure for the
  associated Python package
* ``extensions/bindings.cpp`` provides Python bindings for our operation
* ``extensions/CMakeLists.txt`` holds CMake rules to build the library and
  Python bindings
* ``extensions/setup.py`` holds the ``setuptools`` rules to build and install
  the Python package

Binding to Python
^^^^^^^^^^^^^^^^^

We use nanobind_ to build a Python API for the C++ library. Since bindings for
components such as :class:`mlx.core.array`, :class:`mlx.core.stream`, etc. are
already provided, adding our :meth:`axpby` is simple.

.. code-block:: C++

   NB_MODULE(_ext, m) {
     m.doc() = "Sample extension for MLX";

     m.def(
         "axpby",
         &axpby,
         "x"_a,
         "y"_a,
         "alpha"_a,
         "beta"_a,
         nb::kw_only(),
         "stream"_a = nb::none(),
         R"(
           Scale and sum two vectors element-wise
           ``z = alpha * x + beta * y``

           Follows numpy style broadcasting between ``x`` and ``y``
           Inputs are upcasted to floats if needed

@@ -711,17 +669,17 @@ already provided, adding our :meth:`axpby` is simple!

           Returns:
               array: ``alpha * x + beta * y``
         )");
   }

Most of the complexity in the above example comes from additional bells and
whistles such as the literal names and doc-strings.

.. warning::

   :mod:`mlx.core` must be imported before importing
   :mod:`mlx_sample_extensions` as defined by the nanobind module above to
   ensure that the casters for :mod:`mlx.core` components like
   :class:`mlx.core.array` are available.
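
In practice, the package's ``__init__.py`` can take care of this ordering. A
minimal sketch (the exact module layout here is an assumption, not taken from
the repository):

.. code-block:: python

   # mlx_sample_extensions/__init__.py
   import mlx.core as mx  # imported first so the array/stream casters exist
   from ._ext import axpby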

.. _Building with CMake:

@@ -729,8 +687,8 @@ whistles such as the literal names and doc-strings.

Building with CMake
^^^^^^^^^^^^^^^^^^^

Building the C++ extension library only requires that you ``find_package(MLX
CONFIG)`` and then link it to your library.

.. code-block:: cmake

@@ -752,12 +710,12 @@ Building the C++ extension library itself is simple, it only requires that you

   # Link to mlx
   target_link_libraries(mlx_ext PUBLIC mlx)

We also need to build the attached Metal library. For convenience, we provide a
:meth:`mlx_build_metallib` function that builds a ``.metallib`` target given
sources, headers, destinations, etc. (defined in ``cmake/extension.cmake`` and
automatically imported with the MLX package).

Here is what that looks like in practice:

.. code-block:: cmake

@@ -779,27 +737,29 @@ Here is what that looks like in practice!

   endif()

Finally, we build the nanobind_ bindings:

.. code-block:: cmake

   nanobind_add_module(
     _ext
     NB_STATIC STABLE_ABI LTO NOMINSIZE
     NB_DOMAIN mlx
     ${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
   )
   target_link_libraries(_ext PRIVATE mlx_ext)

   if(BUILD_SHARED_LIBS)
     target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path)
   endif()

Building with ``setuptools``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Once we have set out the CMake build rules as described above, we can use the
build utilities defined in :mod:`mlx.extension`:

.. code-block:: python

   from mlx import extension
   from setuptools import setup

@@ -809,48 +769,50 @@ build utilities defined in :mod:`mlx.extension` for a simple build process.

       name="mlx_sample_extensions",
       version="0.0.0",
       description="Sample C++ and Metal extensions for MLX primitives.",
       ext_modules=[extension.CMakeExtension("mlx_sample_extensions._ext")],
       cmdclass={"build_ext": extension.CMakeBuild},
       packages=["mlx_sample_extensions"],
       package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
       extras_require={"dev": []},
       zip_safe=False,
       python_requires=">=3.8",
   )

.. note::
   We treat ``extensions/mlx_sample_extensions`` as the package directory
   even though it only contains a ``__init__.py`` to ensure the following:

   * :mod:`mlx.core` must be imported before importing :mod:`_ext`
   * The C++ extension library and the Metal library are co-located with the
     Python bindings and copied together if the package is installed

To build the package, first install the build dependencies with ``pip install
-r requirements.txt``. You can then build inplace for development using
``python setup.py build_ext -j8 --inplace`` (in ``extensions/``).

This results in the directory structure:

| extensions
| ├── mlx_sample_extensions
| │   ├── __init__.py
| │   ├── libmlx_ext.dylib # C++ extension library
| │   ├── mlx_ext.metallib # Metal library
| │   └── _ext.cpython-3x-darwin.so # Python Binding
| ...

When you try to install using the command ``python -m pip install .`` (in
``extensions/``), the package will be installed with the same structure as
``extensions/mlx_sample_extensions`` and the C++ and Metal library will be
copied along with the Python binding since they are specified as
``package_data``.

Usage
-----

After installing the extension as described above, you should be able to simply
import the Python package and play with it as you would any other MLX operation.

Let's look at a simple script and its results:

.. code-block:: python

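   # A minimal sketch of such a script; the concrete shapes and values here
   # are illustrative assumptions rather than the repository's exact example.
   import mlx.core as mx
   from mlx_sample_extensions import axpby

   a = mx.ones((3, 4))
   b = mx.ones((3, 4))
   c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)

   print(f"c shape: {c.shape}")
   print(f"c dtype: {c.dtype}")
   print(f"c correctness: {mx.all(c == 6.0).item()}")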
@@ -874,12 +836,12 @@ Output:

   c correctness: True

Results
^^^^^^^

Let's run a quick benchmark and see how our new ``axpby`` operation compares
with the naive :meth:`simple_axpby` we first defined on the CPU.

.. code-block:: python

   import mlx.core as mx
   from mlx_sample_extensions import axpby

@@ -898,7 +860,7 @@ with the naive :meth:`simple_axpby` we defined at first on the CPU.

   alpha = 4.0
   beta = 2.0

   mx.eval(x, y)

   def bench(f):
       # Warm up

@@ -919,30 +881,23 @@ with the naive :meth:`simple_axpby` we defined at first on the CPU.

   print(f"Simple axpby: {simple_time:.3f} s | Custom axpby: {custom_time:.3f} s")

The results are ``Simple axpby: 0.114 s | Custom axpby: 0.109 s``. We see
modest improvements right away!

This operation is now ready to be used to build other operations, in
:class:`mlx.nn.Module` calls, and as part of graph transformations like
:meth:`grad`.
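
For example, nothing stops you from dropping it into a module. The sketch below
is illustrative only; the module itself is hypothetical and not part of the
example package:

.. code-block:: python

   import mlx.core as mx
   import mlx.nn as nn
   from mlx_sample_extensions import axpby

   class ScaledResidual(nn.Module):
       def __init__(self, dims: int):
           super().__init__()
           self.linear = nn.Linear(dims, dims)

       def __call__(self, x):
           # y = 1.0 * linear(x) + 0.5 * x, computed with the custom primitive
           return axpby(self.linear(x), x, 1.0, 0.5)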

Scripts
-------

.. admonition:: Download the code

   The full example code is available in `mlx <https://github.com/ml-explore/mlx/tree/main/examples/extensions/>`_.

.. _Accelerate: https://developer.apple.com/documentation/accelerate/blas?language=objc
.. _Metal: https://developer.apple.com/documentation/metal?language=objc
.. _Metal-cpp: https://developer.apple.com/metal/cpp/
.. _`Metal Specification`: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
.. _`Metal Example`: https://developer.apple.com/documentation/metal/performing_calculations_on_a_gpu?language=objc
.. _nanobind: https://nanobind.readthedocs.io/en/latest/

@@ -1,6 +1,6 @@

cmake_minimum_required(VERSION 3.27)

project(_ext LANGUAGES CXX)

# ----------------------------- Setup -----------------------------
set(CMAKE_CXX_STANDARD 17)

@@ -11,8 +11,12 @@ option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON)

# ----------------------------- Dependencies -----------------------------
find_package(MLX CONFIG REQUIRED)
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
execute_process(
  COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
  OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR)
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
find_package(nanobind CONFIG REQUIRED)

# ----------------------------- Extensions -----------------------------

@@ -38,7 +42,6 @@ target_link_libraries(mlx_ext PUBLIC mlx)

# Build metallib
if(MLX_BUILD_METAL)
  mlx_build_metallib(
    TARGET mlx_ext_metallib
    TITLE mlx_ext

@@ -54,13 +57,15 @@ if(MLX_BUILD_METAL)

endif()

# ----------------------------- Python Bindings -----------------------------
nanobind_add_module(
  _ext
  NB_STATIC STABLE_ABI LTO NOMINSIZE
  NB_DOMAIN mlx
  ${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
)
target_link_libraries(_ext PRIVATE mlx_ext)

if(BUILD_SHARED_LIBS)
  target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path)
endif()

examples/extensions/README.md (new file, 18 lines)
@@ -0,0 +1,18 @@

## Build the extensions

```
pip install -e .
```

For faster builds during development, you can also pre-install the requirements:

```
pip install -r requirements.txt
```

And then run:

```
python setup.py build_ext -j8 --inplace
```
@ -1,4 +1,4 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
@ -43,7 +43,7 @@ array axpby(
|
|||||||
auto promoted_dtype = promote_types(x.dtype(), y.dtype());
|
auto promoted_dtype = promote_types(x.dtype(), y.dtype());
|
||||||
|
|
||||||
// Upcast to float32 for non-floating point inputs x and y
|
// Upcast to float32 for non-floating point inputs x and y
|
||||||
auto out_dtype = is_floating_point(promoted_dtype)
|
auto out_dtype = issubdtype(promoted_dtype, float32)
|
||||||
? promoted_dtype
|
? promoted_dtype
|
||||||
: promote_types(promoted_dtype, float32);
|
: promote_types(promoted_dtype, float32);
|
||||||
|
|
||||||
@ -106,12 +106,12 @@ void axpby_impl(
|
|||||||
/** Fall back implementation for evaluation on CPU */
|
/** Fall back implementation for evaluation on CPU */
|
||||||
void Axpby::eval(
|
void Axpby::eval(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
std::vector<array>& out_arr) {
|
std::vector<array>& outputs) {
|
||||||
auto out = out_arr[0];
|
|
||||||
// Check the inputs (registered in the op while constructing the out array)
|
// Check the inputs (registered in the op while constructing the out array)
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& x = inputs[0];
|
auto& x = inputs[0];
|
||||||
auto& y = inputs[1];
|
auto& y = inputs[1];
|
||||||
|
auto& out = outputs[0];
|
||||||
|
|
||||||
// Dispatch to the correct dtype
|
// Dispatch to the correct dtype
|
||||||
if (out.dtype() == float32) {
|
if (out.dtype() == float32) {
|
||||||
@ -150,11 +150,7 @@ void axpby_impl_accelerate(
|
|||||||
// The data in the output array is allocated to match the strides in y
|
// The data in the output array is allocated to match the strides in y
|
||||||
// such that x, y, and out are contiguous in the same mode and
|
// such that x, y, and out are contiguous in the same mode and
|
||||||
// no transposition is needed
|
// no transposition is needed
|
||||||
out.set_data(
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
allocator::malloc_or_wait(y.data_size() * out.itemsize()),
|
|
||||||
y.data_size(),
|
|
||||||
y.strides(),
|
|
||||||
y.flags());
|
|
||||||
|
|
||||||
// We then copy over the elements using the contiguous vector specialization
|
// We then copy over the elements using the contiguous vector specialization
|
||||||
copy_inplace(y, out, CopyType::Vector);
|
copy_inplace(y, out, CopyType::Vector);
|
||||||
@@ -180,11 +176,11 @@ void axpby_impl_accelerate(
 /** Evaluate primitive on CPU using accelerate specializations */
 void Axpby::eval_cpu(
     const std::vector<array>& inputs,
-    std::vector<array>& outarr) {
-  auto out = outarr[0];
+    std::vector<array>& outputs) {
   assert(inputs.size() == 2);
   auto& x = inputs[0];
   auto& y = inputs[1];
+  auto& out = outputs[0];
 
   // Accelerate specialization for contiguous single precision float arrays
   if (out.dtype() == float32 &&
@@ -195,7 +191,7 @@ void Axpby::eval_cpu(
   }
 
   // Fall back to common backend if specializations are not available
-  eval(inputs, outarr);
+  eval(inputs, outputs);
 }
 
 #else // Accelerate not available
@@ -203,8 +199,8 @@ void Axpby::eval_cpu(
 /** Evaluate primitive on CPU falling back to common backend */
 void Axpby::eval_cpu(
     const std::vector<array>& inputs,
-    std::vector<array>& out) {
-  eval(inputs, out);
+    const std::vector<array>& outputs) {
+  eval(inputs, outputs);
 }
 
 #endif
@@ -218,12 +214,12 @@ void Axpby::eval_cpu(
 /** Evaluate primitive on GPU */
 void Axpby::eval_gpu(
     const std::vector<array>& inputs,
-    std::vector<array>& outarr) {
+    std::vector<array>& outputs) {
   // Prepare inputs
-  auto out = outarr[0];
   assert(inputs.size() == 2);
   auto& x = inputs[0];
   auto& y = inputs[1];
+  auto& out = outputs[0];
 
   // Each primitive carries the stream it should execute on
   // and each stream carries its device identifiers
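As the comment notes, each primitive carries the stream it executes on. A hedged sketch of steering the op to a particular backend from Python, assuming the bound `stream` argument accepts a `Stream` or `Device` the way built-in MLX ops do, and that a Metal-capable machine is available for the GPU case:

```python
# Route the op to the CPU or GPU backend via the keyword-only `stream` argument.
import mlx.core as mx
from mlx_sample_extensions import _ext

x = mx.ones((1024,))
y = mx.ones((1024,))
z_cpu = _ext.axpby(x, y, 2.0, 3.0, stream=mx.cpu)  # dispatches to Axpby::eval_cpu
z_gpu = _ext.axpby(x, y, 2.0, 3.0, stream=mx.gpu)  # dispatches to Axpby::eval_gpu
mx.eval(z_cpu, z_gpu)
```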
@@ -372,4 +368,4 @@ bool Axpby::is_equivalent(const Primitive& other) const {
   return alpha_ == r_other.alpha_ && beta_ == r_other.beta_;
 }
 
 } // namespace mlx::core
examples/extensions/axpby/axpby.h
@@ -42,9 +42,9 @@ class Axpby : public Primitive {
    * To avoid unnecessary allocations, the evaluation function
    * is responsible for allocating space for the array.
    */
-  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& out)
+  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override;
-  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& out)
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override;
 
   /** The Jacobian-vector product. */
@@ -83,7 +83,7 @@ class Axpby : public Primitive {
   float beta_;
 
   /** Fall back implementation for evaluation on CPU */
-  void eval(const std::vector<array>& inputs, std::vector<array>& out);
+  void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
 };
 
 } // namespace mlx::core
examples/extensions/bindings.cpp
@@ -1,31 +1,31 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
 
-#include <pybind11/pybind11.h>
-#include <pybind11/stl.h>
+#include <nanobind/nanobind.h>
+#include <nanobind/stl/variant.h>
 
 #include "axpby/axpby.h"
 
-namespace py = pybind11;
-using namespace py::literals;
+namespace nb = nanobind;
+using namespace nb::literals;
 
 using namespace mlx::core;
 
-PYBIND11_MODULE(mlx_sample_extensions, m) {
-  m.doc() = "Sample C++ and metal extensions for MLX";
+NB_MODULE(_ext, m) {
+  m.doc() = "Sample extension for MLX";
 
   m.def(
       "axpby",
       &axpby,
       "x"_a,
       "y"_a,
-      py::pos_only(),
       "alpha"_a,
       "beta"_a,
-      py::kw_only(),
+      nb::kw_only(),
-      "stream"_a = py::none(),
+      "stream"_a = nb::none(),
-      R"pbdoc(
+      R"(
         Scale and sum two vectors element-wise
         ``z = alpha * x + beta * y``
 
         Follows numpy style broadcasting between ``x`` and ``y``
         Inputs are upcasted to floats if needed
@@ -37,5 +37,5 @@ PYBIND11_MODULE(mlx_sample_extensions, m) {
 
         Returns:
             array: ``alpha * x + beta * y``
-      )pbdoc");
+      )");
 }
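With the port from pybind11 to nanobind, `py::pos_only()` is gone, so `x`, `y`, `alpha`, and `beta` can all be passed by keyword while `stream` stays keyword-only. A short sketch, assuming the same import path used above; the broadcasting behavior is the one promised by the docstring:

```python
# Keyword arguments plus numpy-style broadcasting between x and y.
import mlx.core as mx
from mlx_sample_extensions import _ext

x = mx.ones((4, 1))
y = mx.ones((1, 3))
z = _ext.axpby(x=x, y=y, alpha=2.0, beta=3.0)  # broadcasts x and y to shape (4, 3)
print(z.shape)
```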
examples/extensions/pyproject.toml
@@ -1,3 +1,8 @@
 [build-system]
-requires = ["setuptools>=42", "pybind11>=2.10", "cmake>=3.24", "mlx @ git+https://github.com/mlx-explore/mlx@main"]
-build-backend = "setuptools.build_meta"
+requires = [
+    "setuptools>=42",
+    "cmake>=3.24",
+    "mlx>=0.9.0",
+    "nanobind@git+https://github.com/wjakob/nanobind.git#egg=4148debcf91f5ccab0c3b8d67b5c3cabd61f407f",
+]
+build-backend = "setuptools.build_meta"
examples/extensions/requirements.txt (new file)
@@ -0,0 +1,4 @@
+setuptools>=42
+cmake>=3.24
+mlx>=0.9.0
+nanobind@git+https://github.com/wjakob/nanobind.git#egg=4148debcf91f5ccab0c3b8d67b5c3cabd61f407f
examples/extensions/setup.py
@@ -1,4 +1,4 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.
 
 from setuptools import setup
 
@@ -9,11 +9,11 @@ if __name__ == "__main__":
         name="mlx_sample_extensions",
         version="0.0.0",
         description="Sample C++ and Metal extensions for MLX primitives.",
-        ext_modules=[extension.CMakeExtension("mlx_sample_extensions")],
+        ext_modules=[extension.CMakeExtension("mlx_sample_extensions._ext")],
         cmdclass={"build_ext": extension.CMakeBuild},
         packages=["mlx_sample_extensions"],
-        package_dir={"": "."},
         package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
+        extras_require={"dev": []},
         zip_safe=False,
         python_requires=">=3.8",
     )
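The extension is now built as `mlx_sample_extensions._ext` rather than as a top-level `mlx_sample_extensions` binary, so the compiled library lands inside the Python package next to its `__init__.py`. A hypothetical check of the installed layout, assuming the editable install from the README:

```python
# The dotted CMakeExtension name means the compiled nanobind module lives
# inside the package; importing it directly confirms the layout.
import importlib

ext = importlib.import_module("mlx_sample_extensions._ext")
print(ext.axpby.__doc__)  # docstring registered in bindings.cpp
```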