diff --git a/docs/src/dev/extensions.rst b/docs/src/dev/extensions.rst
index 9198548a4..acf41a773 100644
--- a/docs/src/dev/extensions.rst
+++ b/docs/src/dev/extensions.rst
@@ -1,24 +1,16 @@
 Developer Documentation
 =======================
 
-MLX provides a open and flexible backend to which users may add operations
-and specialized implementations without much hassle. While the library supplies
-efficient operations that can be used and composed for any number of
-applications, there may arise cases where new functionalities or highly
-optimized implementations are needed. For such cases, you may design and
-implement your own operations that link to and build on top of :mod:`mlx.core`.
-We will introduce the inner-workings of MLX and go over a simple example to
-learn the steps involved in adding new operations to MLX with your own CPU
-and GPU implementations.
+You can extend MLX with custom operations on the CPU or GPU. This guide
+explains how to do that with a simple example.
 
 Introducing the Example
 -----------------------
 
-Let's say that you would like an operation that takes in two arrays,
-``x`` and ``y``, scales them both by some coefficients ``alpha`` and ``beta``
-respectively, and then adds them together to get the result
-``z = alpha * x + beta * y``. Well, you can very easily do that by just
-writing out a function as follows:
+Let's say you would like an operation that takes in two arrays, ``x`` and
+``y``, scales them both by coefficients ``alpha`` and ``beta`` respectively,
+and then adds them together to get the result ``z = alpha * x + beta * y``.
+You can do that in MLX directly:
 
 .. code-block:: python
 
    import mlx.core as mx
 
    def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
        return alpha * x + beta * y
 
-This function performs that operation while leaving the implementations and
-differentiation to MLX.
+This function performs that operation while leaving the implementation and
+function transformations to MLX.
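+
+Because ``simple_axpby`` is a composition of MLX operations, it already works
+with function transformations. A quick sanity check (the values here are only
+illustrative):
+
+.. code-block:: python
+
+   x = mx.array([1.0, 2.0, 3.0])
+   y = mx.array([4.0, 5.0, 6.0])
+
+   z = simple_axpby(x, y, alpha=4.0, beta=2.0)  # [12.0, 18.0, 24.0]
+
+   # Transformations such as grad compose with it out of the box:
+   # d/dx sum(alpha * x + beta * y) = alpha
+   grad_fn = mx.grad(lambda x: simple_axpby(x, y, 4.0, 2.0).sum())
+   print(grad_fn(x))  # array([4, 4, 4], dtype=float32)
+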
 
-However, you work with vector math libraries often and realize that the
-``axpby`` routine defines the same operation ``Y = (alpha * X) + (beta * Y)``.
-You would really like the part of your applications that does this operation
-on the CPU to be very fast - so you decide that you want it to rely on the
-``axpby`` routine provided by the Accelerate_ framework. Continuing to impose
-our assumptions on to you, let's also assume that you want to learn how to add
-your own implementation for the gradients of your new operation while going
-over the ins-and-outs of the MLX framework.
+However, you may need to customize the underlying implementation, perhaps to
+make it faster or for custom differentiation. In this tutorial we will go
+through adding custom extensions. It will cover:
 
-Well, what a coincidence! You are in the right place. Over the course of this
-example, we will learn:
-
-* The structure of the MLX library from the frontend API to the backend implementations.
-* How to implement your own CPU backend that redirects to Accelerate_ when appropriate (and a fallback if needed).
-* How to implement your own GPU implementation using metal.
-* How to add your own ``vjp`` and ``jvp``.
-* How to build your implementations, link them to MLX, and bind them to python.
+* The structure of the MLX library.
+* Implementing a CPU operation that redirects to Accelerate_ when appropriate.
+* Implementing a GPU operation using Metal.
+* Adding the ``vjp`` and ``jvp`` function transformations.
+* Building a custom extension and binding it to Python.
 
 Operations and Primitives
 -------------------------
 
-In one sentence, operations in MLX build the computation graph, and primitives
-provide the rules for evaluation and transformations of said graph. Let's start
-by discussing operations in more detail.
+Operations in MLX build the computation graph. Primitives provide the rules for
+evaluating and transforming the graph. Let's start by discussing operations in
+more detail.
 
 Operations
 ^^^^^^^^^^^
 
-Operations are the frontend functions that operate on arrays. They are defined
-in the C++ API (:ref:`cpp_ops`) and then we provide bindings to these
-operations in the Python API (:ref:`ops`).
+Operations are the front-end functions that operate on arrays. They are defined
+in the C++ API (:ref:`cpp_ops`), and the Python API (:ref:`ops`) binds them.
 
-We would like an operation, :meth:`axpby` that takes in two arrays ``x`` and ``y``,
-and two scalars, ``alpha`` and ``beta``. This is how we would define it in the
-C++ API:
+We would like an operation, :meth:`axpby`, that takes in two arrays ``x`` and
+``y``, and two scalars, ``alpha`` and ``beta``. This is how to define it in
+C++:
 
 .. code-block:: C++
 
@@ -83,10 +66,7 @@ C++ API:
       StreamOrDevice s = {} // Stream on which to schedule the operation
   );
 
-
-This operation itself can call other operations within it if needed. So, the
-simplest way to go about implementing this operation would be do so in terms
-of existing operations.
+The simplest way to implement this operation is in terms of existing
+operations:
 
 .. code-block:: C++
 
@@ -100,25 +80,23 @@ of existing operations.
     // Scale x and y on the provided stream
    auto ax = multiply(array(alpha), x, s);
    auto by = multiply(array(beta), y, s);
-
+
    // Add and return
    return add(ax, by, s);
  }
 
-However, as we discussed earlier, this is not our goal. The operations themselves
-do not contain the implementations that act on the data, nor do they contain the
-rules of transformations. Rather, they are an easy to use interface that build
-on top of the building blocks we call :class:`Primitive`.
+The operations themselves do not contain the implementations that act on the
+data, nor do they contain the rules of transformations. Rather, they are an
+easy-to-use interface that uses :class:`Primitive` building blocks.
 
 Primitives
 ^^^^^^^^^^^
 
-A :class:`Primitive` is part of the computation graph of an :class:`array`. It
-defines how to create an output given a set of input :class:`array` . Further,
-a :class:`Primitive` is a class that contains rules on how it is evaluated
-on the CPU or GPU, and how it acts under transformations such as ``vjp`` and
-``jvp``. These words on their own can be a bit abstract, so lets take a step
-back and go to our example to give ourselves a more concrete image.
+A :class:`Primitive` is part of the computation graph of an :class:`array`. It
+defines how to create output arrays given input arrays. Further, a
+:class:`Primitive` has methods to run on the CPU or GPU and for function
+transformations such as ``vjp`` and ``jvp``. Let's go back to our example to be
+more concrete:
 
 .. code-block:: C++
 
@@ -134,11 +112,15 @@ back and go to our example to give ourselves a more concrete image.
    * To avoid unnecessary allocations, the evaluation function
   * is responsible for allocating space for the array.
   */
-  void eval_cpu(const std::vector<array>& inputs, array& out) override;
-  void eval_gpu(const std::vector<array>& inputs, array& out) override;
+  void eval_cpu(
+      const std::vector<array>& inputs,
+      std::vector<array>& outputs) override;
+  void eval_gpu(
+      const std::vector<array>& inputs,
+      std::vector<array>& outputs) override;
 
  /** The Jacobian-vector product. */
-  array jvp(
+  std::vector<array> jvp(
      const std::vector<array>& primals,
      const std::vector<array>& tangents,
      const std::vector<int>& argnums) override;
 
@@ -147,7 +129,8 @@ back and go to our example to give ourselves a more concrete image.
  /** The vector-Jacobian product. */
  std::vector<array> vjp(
      const std::vector<array>& primals,
-      const array& cotan,
-      const std::vector<int>& argnums) override;
+      const std::vector<array>& cotangents,
+      const std::vector<int>& argnums,
+      const std::vector<array>& outputs) override;
 
  /**
   * The primitive must know how to vectorize itself across
@@ -155,7 +138,7 @@ back and go to our example to give ourselves a more concrete image.
   * representing the vectorized computation and the axis which
   * corresponds to the output vectorized dimension.
   */
-  std::pair<array, int> vmap(
+  std::pair<std::vector<array>, std::vector<int>> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;
 
@@ -175,22 +158,22 @@ back and go to our example to give ourselves a more concrete image.
 
-  void eval(const std::vector<array>& inputs, array& out);
+  void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
 };
 
-The :class:`Axpby` class derives from the base :class:`Primitive` class and
-follows the above demonstrated interface. :class:`Axpby` treats ``alpha`` and
-``beta`` as parameters. It then provides implementations of how the array ``out``
-is produced given ``inputs`` through :meth:`Axpby::eval_cpu` and
-:meth:`Axpby::eval_gpu`. Further, it provides rules of transformations in
-:meth:`Axpby::jvp`, :meth:`Axpby::vjp`, and :meth:`Axpby::vmap`.
+The :class:`Axpby` class derives from the base :class:`Primitive` class.
+:class:`Axpby` treats ``alpha`` and ``beta`` as parameters. It then provides
+implementations of how the output array is produced given the inputs through
+:meth:`Axpby::eval_cpu` and :meth:`Axpby::eval_gpu`. It also provides rules
+of transformations in :meth:`Axpby::jvp`, :meth:`Axpby::vjp`, and
+:meth:`Axpby::vmap`.
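+
+The class declaration above elides the constructor; it simply captures the
+stream and the parameters. A minimal sketch (the exact form is assumed, since
+it is not shown in this diff):
+
+.. code-block:: C++
+
+   explicit Axpby(Stream stream, float alpha, float beta)
+       : Primitive(stream), alpha_(alpha), beta_(beta) {}
+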
 
-Using the Primitives
-^^^^^^^^^^^^^^^^^^^^^
+Using the Primitive
+^^^^^^^^^^^^^^^^^^^
 
-Operations can use this :class:`Primitive` to add a new :class:`array` to
-the computation graph. An :class:`array` can be constructed by providing its
-data type, shape, the :class:`Primitive` that computes it, and the
-:class:`array` inputs that are passed to the primitive.
+Operations can use this :class:`Primitive` to add a new :class:`array` to the
+computation graph. An :class:`array` can be constructed by providing its data
+type, shape, the :class:`Primitive` that computes it, and the :class:`array`
+inputs that are passed to the primitive.
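+
+Concretely, constructing the output array looks roughly like the following
+(a sketch: the full operation body is elided from this diff, and
+``out_shape``, ``out_dtype``, and ``broadcasted_inputs`` are assumed names):
+
+.. code-block:: C++
+
+   // Construct the array as the output of the Axpby primitive
+   // with the broadcasted and upcasted arrays as inputs
+   return array(
+       /* const std::vector<int>& shape = */ out_shape,
+       /* Dtype dtype = */ out_dtype,
+       /* std::shared_ptr<Primitive> primitive = */
+       std::make_shared<Axpby>(to_stream(s), alpha, beta),
+       /* const std::vector<array>& inputs = */ broadcasted_inputs);
+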
 
-Let's re-implement our operation now in terms of our :class:`Axpby` primitive.
+Let's reimplement our operation now in terms of our :class:`Axpby` primitive.
 
 .. code-block:: C++
 
@@ -238,27 +221,26 @@ This operation now handles the following:
 
 Implementing the Primitive
 --------------------------
 
-No computation happens when we call the operation alone. In effect, the
-operation only builds the computation graph. When we evaluate the output
-array, MLX schedules the execution of the computation graph, and calls
-:meth:`Axpby::eval_cpu` or :meth:`Axpby::eval_gpu` depending on the
-stream/device specified by the user.
+No computation happens when we call the operation alone. The operation only
+builds the computation graph. When we evaluate the output array, MLX schedules
+the execution of the computation graph, and calls :meth:`Axpby::eval_cpu` or
+:meth:`Axpby::eval_gpu` depending on the stream/device specified by the user.
 
 .. warning::
   When :meth:`Primitive::eval_cpu` or :meth:`Primitive::eval_gpu` are called, no
  memory has been allocated for the output array. It falls on the implementation
-  of these functions to allocate memory as needed
+  of these functions to allocate memory as needed.
 
-Implementing the CPU Backend
+Implementing the CPU Back-end
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-Let's start by trying to implement a naive and generic version of
-:meth:`Axpby::eval_cpu`. We declared this as a private member function of
-:class:`Axpby` earlier called :meth:`Axpby::eval`.
+Let's start by implementing a naive and generic version of
+:meth:`Axpby::eval_cpu`. We declared this as a private member function of
+:class:`Axpby` earlier called :meth:`Axpby::eval`.
 
-Our naive method will go over each element of the output array, find the
-corresponding input elements of ``x`` and ``y`` and perform the operation
-pointwise. This is captured in the templated function :meth:`axpby_impl`.
+Our naive method will go over each element of the output array, find the
+corresponding input elements of ``x`` and ``y``, and perform the operation
+point-wise. This is captured in the templated function :meth:`axpby_impl`.
 
 .. code-block:: C++
 
@@ -296,19 +278,19 @@ pointwise. This is captured in the templated function :meth:`axpby_impl`.
     }
  }
 
-Now, we would like our implementation to be able to do this pointwise operation
-for all incoming floating point arrays. Accordingly, we add dispatches for
-``float32``, ``float16``, ``bfloat16`` and ``complex64``. We throw an error
-if we encounter an unexpected type.
+Our implementation should work for all incoming floating point arrays.
+Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and
+``complex64``. We throw an error if we encounter an unexpected type.
 
 .. code-block:: C++
 
   /** Fall back implementation for evaluation on CPU */
-  void Axpby::eval(const std::vector<array>& inputs, array& out) {
-    // Check the inputs (registered in the op while constructing the out array)
-    assert(inputs.size() == 2);
+  void Axpby::eval(
+      const std::vector<array>& inputs,
+      std::vector<array>& outputs) {
    auto& x = inputs[0];
    auto& y = inputs[1];
+    auto& out = outputs[0];
 
    // Dispatch to the correct dtype
    if (out.dtype() == float32) {
@@ -321,28 +303,26 @@ if we encounter an unexpected type.
       return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
    } else {
      throw std::runtime_error(
-          "Axpby is only supported for floating point types.");
+          "[Axpby] Only supports floating point types.");
    }
  }
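+
+The body of :meth:`axpby_impl` is unchanged context elided from this diff.
+For reference, it is roughly the following (a sketch, not the verbatim
+source):
+
+.. code-block:: C++
+
+   template <typename T>
+   void axpby_impl(
+       const array& x,
+       const array& y,
+       array& out,
+       float alpha_,
+       float beta_) {
+     // We only allocate memory when we are ready to fill the output
+     out.set_data(allocator::malloc_or_wait(out.nbytes()));
+
+     // Collect input and output data pointers
+     const T* x_ptr = x.data<T>();
+     const T* y_ptr = y.data<T>();
+     T* out_ptr = out.data<T>();
+
+     // Cast alpha and beta to the relevant types
+     T alpha = static_cast<T>(alpha_);
+     T beta = static_cast<T>(beta_);
+
+     // Do the element-wise operation for each output
+     for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
+       // Map linear indices to offsets in x and y using their strides
+       auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
+       auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
+
+       // The output is row contiguous, so a linear index needs no mapping
+       out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
+     }
+   }
+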
 
-We have a fallback implementation! Now, to do what we are really here to do.
-Remember we wanted to use the ``axpby`` routine provided by the Accelerate_
-framework? Well, there are 3 complications to keep in mind:
+This is good as a fallback implementation. We can use the ``axpby`` routine
+provided by the Accelerate_ framework for a faster implementation in certain
+cases:
 
 #. Accelerate does not provide implementations of ``axpby`` for half precision
-   floats. We can only direct to it for ``float32`` types
-#. Accelerate assumes the inputs ``x`` and ``y`` are contiguous and all elements
-   have fixed strides between them. Possibly due to broadcasts and transposes,
-   we aren't guaranteed that the inputs fit this requirement. We can
-   only direct to Accelerate if both ``x`` and ``y`` are row contiguous or
-   column contiguous.
-#. Accelerate performs the routine ``Y = (alpha * X) + (beta * Y)`` inplace.
-   MLX expects to write out the answer to a new array. We must copy the elements
-   of ``y`` into the output array and use that as an input to ``axpby``
+   floats. We can only use it for ``float32`` types.
+#. Accelerate assumes the inputs ``x`` and ``y`` are contiguous and all
+   elements have fixed strides between them. We only direct to Accelerate
+   if both ``x`` and ``y`` are row contiguous or column contiguous.
+#. Accelerate performs the routine ``Y = (alpha * X) + (beta * Y)`` in-place.
+   MLX expects to write the output to a new array. We must copy the elements
+   of ``y`` into the output and use that as an input to ``axpby``.
 
-Let's write out an implementation that uses Accelerate in the right conditions.
-It must simply allocate data for the output, copy elements of ``y`` into it,
-and then call the :meth:`catlas_saxpby` from accelerate.
+Let's write an implementation that uses Accelerate in the right conditions.
+It allocates data for the output, copies ``y`` into it, and then calls
+:func:`catlas_saxpby` from Accelerate.
 
 .. code-block:: C++
 
@@ -356,17 +336,7 @@ and then call the :meth:`catlas_saxpby` from accelerate.
     // Accelerate library provides catlas_saxpby which does
    // Y = (alpha * X) + (beta * Y) in place
    // To use it, we first copy the data in y over to the output array
-
-    // This specialization requires both x and y be contiguous in the same mode
-    // i.e: corresponding linear indices in both point to corresponding elements
-    // The data in the output array is allocated to match the strides in y
-    // such that x, y, and out are contiguous in the same mode and
-    // no transposition is needed
-    out.set_data(
-        allocator::malloc_or_wait(y.data_size() * out.itemsize()),
-        y.data_size(),
-        y.strides(),
-        y.flags());
+    out.set_data(allocator::malloc_or_wait(out.nbytes()));
 
    // We then copy over the elements using the contiguous vector specialization
    copy_inplace(y, out, CopyType::Vector);
 
@@ -389,18 +359,20 @@ and then call the :meth:`catlas_saxpby` from accelerate.
        /* INCY = */ 1);
  }
 
-Great! But what about the inputs that do not fit the criteria for accelerate?
-Luckily, we can always just direct back to :meth:`Axpby::eval`.
-
-With this in mind, lets finally implement our :meth:`Axpby::eval_cpu`.
+For inputs that do not fit the criteria for Accelerate, we fall back to
+:meth:`Axpby::eval`. With this in mind, let's finish our
+:meth:`Axpby::eval_cpu`.
 
 .. code-block:: C++
 
   /** Evaluate primitive on CPU using accelerate specializations */
-  void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
+  void Axpby::eval_cpu(
+      const std::vector<array>& inputs,
+      std::vector<array>& outputs) {
    assert(inputs.size() == 2);
    auto& x = inputs[0];
    auto& y = inputs[1];
+    auto& out = outputs[0];
 
    // Accelerate specialization for contiguous single precision float arrays
    if (out.dtype() == float32 &&
@@ -410,35 +382,33 @@ With this in mind, lets finally implement our :meth:`Axpby::eval_cpu`.
       return;
    }
 
-    // Fall back to common backend if specializations are not available
-    eval(inputs, out);
+    // Fall back to common back-end if specializations are not available
+    eval(inputs, outputs);
  }
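+
+The elided condition in :meth:`Axpby::eval_cpu` above checks the array flags.
+It is roughly (a sketch, assuming the standard MLX flag names):
+
+.. code-block:: C++
+
+   if (out.dtype() == float32 &&
+       ((x.flags().row_contiguous && y.flags().row_contiguous) ||
+        (x.flags().col_contiguous && y.flags().col_contiguous)) &&
+       x.size() == y.size()) {
+     axpby_impl_accelerate(x, y, out, alpha_, beta_);
+     return;
+   }
+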
If +you do not plan on running the operation on the GPU or using transforms on +computation graphs that contain :class:`Axpby`, you can stop implementing the +primitive here and enjoy the speed-ups you get from the Accelerate library. -If you do not plan on running the operation on the GPU or using transforms on -computation graphs that contain :class:`Axpby`, you can stop implementing the -primitive here and enjoy the speed-ups you get from the Accelerate library. - -Implementing the GPU Backend +Implementing the GPU Back-end ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Apple silicon devices address their GPUs using the Metal_ shading language, and -all GPU kernels in MLX are written using metal. +Apple silicon devices address their GPUs using the Metal_ shading language, and +GPU kernels in MLX are written using Metal. .. note:: - Here are some helpful resources if you are new to metal! + Here are some helpful resources if you are new to Metal: * A walkthrough of the metal compute pipeline: `Metal Example`_ * Documentation for metal shading language: `Metal Specification`_ * Using metal from C++: `Metal-cpp`_ -Let's keep the GPU algorithm simple. We will launch exactly as many threads -as there are elements in the output. Each thread will pick the element it needs -from ``x`` and ``y``, do the pointwise operation, and then update its assigned -element in the output. +Let's keep the GPU kernel simple. We will launch exactly as many threads as +there are elements in the output. Each thread will pick the element it needs +from ``x`` and ``y``, do the point-wise operation, and update its assigned +element in the output. .. code-block:: C++ @@ -457,15 +427,14 @@ element in the output. // Convert linear indices to offsets in array auto x_offset = elem_to_loc(index, shape, x_strides, ndim); auto y_offset = elem_to_loc(index, shape, y_strides, ndim); - + // Do the operation and update the output - out[index] = + out[index] = static_cast(alpha) * x[x_offset] + static_cast(beta) * y[y_offset]; } We then need to instantiate this template for all floating point types and give -each instantiation a unique host name so we can identify the right kernel for -each data type. +each instantiation a unique host name so we can identify it. .. code-block:: C++ @@ -488,29 +457,21 @@ each data type. instantiate_axpby(bfloat16, bfloat16_t); instantiate_axpby(complex64, complex64_t); -This kernel will be compiled into a metal library ``mlx_ext.metallib`` as we -will see later in :ref:`Building with CMake`. In the following example, we -assume that the library ``mlx_ext.metallib`` will always be co-located with -the executable/ shared-library calling the :meth:`register_library` function. -The :meth:`register_library` function takes the library's name and potential -path (or in this case, a function that can produce the path of the metal -library) and tries to load that library if it hasn't already been registered -by the relevant static :class:`mlx::core::metal::Device` object. This is why, -it is important to package your C++ library with the metal library. We will -go over this process in more detail later. - -The logic to determine the kernel, set the inputs, resolve the grid dimensions -and dispatch it to the GPU are contained in :meth:`Axpby::eval_gpu` as shown +The logic to determine the kernel, set the inputs, resolve the grid dimensions, +and dispatch to the GPU are contained in :meth:`Axpby::eval_gpu` as shown below. .. 
 
 We then need to instantiate this template for all floating point types and give
-each instantiation a unique host name so we can identify the right kernel for
-each data type.
+each instantiation a unique host name so we can identify it.
 
 .. code-block:: C++
 
@@ -488,29 +457,21 @@ each data type.
 
   instantiate_axpby(bfloat16, bfloat16_t);
  instantiate_axpby(complex64, complex64_t);
 
-This kernel will be compiled into a metal library ``mlx_ext.metallib`` as we
-will see later in :ref:`Building with CMake`. In the following example, we
-assume that the library ``mlx_ext.metallib`` will always be co-located with
-the executable/ shared-library calling the :meth:`register_library` function.
-The :meth:`register_library` function takes the library's name and potential
-path (or in this case, a function that can produce the path of the metal
-library) and tries to load that library if it hasn't already been registered
-by the relevant static :class:`mlx::core::metal::Device` object. This is why,
-it is important to package your C++ library with the metal library. We will
-go over this process in more detail later.
-
-The logic to determine the kernel, set the inputs, resolve the grid dimensions
-and dispatch it to the GPU are contained in :meth:`Axpby::eval_gpu` as shown
+The logic to determine the kernel, set the inputs, resolve the grid dimensions,
+and dispatch to the GPU are contained in :meth:`Axpby::eval_gpu` as shown
 below.
 
 .. code-block:: C++
 
   /** Evaluate primitive on GPU */
-  void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
+  void Axpby::eval_gpu(
+      const std::vector<array>& inputs,
+      std::vector<array>& outputs) {
    // Prepare inputs
    assert(inputs.size() == 2);
    auto& x = inputs[0];
    auto& y = inputs[1];
+    auto& out = outputs[0];
 
    // Each primitive carries the stream it should execute on
    // and each stream carries its device identifiers
@@ -518,10 +479,10 @@ below.
     // We get the needed metal device using the stream
    auto& d = metal::device(s.device);
 
-    // Allocate output memory
+    // Allocate output memory
    out.set_data(allocator::malloc_or_wait(out.nbytes()));
 
-    // Resolve name of kernel (corresponds to axpby.metal)
+    // Resolve name of kernel
    std::ostringstream kname;
    kname << "axpby_" << "general_" << type_to_name(out);
 
@@ -552,7 +513,7 @@ below.
     compute_encoder->setBytes(&alpha_, sizeof(float), 3);
    compute_encoder->setBytes(&beta_, sizeof(float), 4);
 
-    // Encode shape, strides and ndim
+    // Encode shape, strides and ndim
    compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
    compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
    compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
 
@@ -575,28 +536,25 @@ below.
 
 We can now call the :meth:`axpby` operation on both the CPU and the GPU!
 
-A few things to note about MLX and metal before moving on. MLX keeps track
-of the active ``compute_encoder``. We rely on :meth:`d.get_command_encoder`
-to give us the active metal compute command encoder instead of building a
-new one and calling :meth:`compute_encoder->end_encoding` at the end.
-MLX keeps adding kernels (compute pipelines) to the active command encoder
-until some specified limit is hit or the compute encoder needs to be flushed
-for synchronization. MLX also handles enqueuing and committing the associated
-command buffers as needed. We suggest taking a deeper dive into
-:class:`metal::Device` if you would like to study this routine further.
+A few things to note about MLX and Metal before moving on. MLX keeps track of
+the active ``compute_encoder`` and the ``MTLCommandBuffer`` it is associated
+with. We rely on :meth:`d.get_command_encoder` to give us the active metal
+compute command encoder instead of building a new one and calling
+:meth:`compute_encoder->end_encoding` at the end. MLX adds kernels (compute
+pipelines) to the active command buffer until some specified limit is hit or
+the command buffer needs to be flushed for synchronization.
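+
+The tail of :meth:`Axpby::eval_gpu` (elided above) fixes the grid and
+threadgroup sizes and launches the kernel. It is roughly the following
+(a sketch, assuming ``nelem`` holds ``out.size()``):
+
+.. code-block:: C++
+
+   // Fix the 3D size of each threadgroup (in terms of threads)
+   MTL::Size group_dims = MTL::Size(1024, 1, 1);
+
+   // Fix the 3D size of the launch grid (in terms of threads)
+   MTL::Size grid_dims = MTL::Size(nelem, 1, 1);
+
+   // Launch the grid with the given number of threads divided among
+   // the given threadgroups
+   compute_encoder->dispatchThreads(grid_dims, group_dims);
+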
 
 Primitive Transforms
 ^^^^^^^^^^^^^^^^^^^^^
 
-Now that we have come this far, let's also learn how to add implementations to
-transformations in a :class:`Primitive`. These transformations can be built on
-top of our operations, including the one we just defined now. Which then gives
-us the following :meth:`Axpby::jvp` and :meth:`Axpby::vjp` implementations.
+Next, let's add implementations for transformations in a :class:`Primitive`.
+These transformations can be built on top of other operations, including the
+one we just defined:
 
 .. code-block:: C++
 
   /** The Jacobian-vector product. */
-  array Axpby::jvp(
+  std::vector<array> Axpby::jvp(
      const std::vector<array>& primals,
      const std::vector<array>& tangents,
      const std::vector<int>& argnums) {
 
@@ -611,12 +569,12 @@ us the following :meth:`Axpby::jvp` and :meth:`Axpby::vjp` implementations.
     if (argnums.size() > 1) {
      auto scale = argnums[0] == 0 ? alpha_ : beta_;
      auto scale_arr = array(scale, tangents[0].dtype());
-      return multiply(scale_arr, tangents[0], stream());
+      return {multiply(scale_arr, tangents[0], stream())};
    }
    // If argnums = {0, 1}, we take contributions from both,
    // which gives us jvp = tangent_x * alpha + tangent_y * beta
    else {
-      return axpby(tangents[0], tangents[1], alpha_, beta_, stream());
+      return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
    }
  }
 
@@ -625,34 +583,35 @@ us the following :meth:`Axpby::jvp` and :meth:`Axpby::vjp` implementations.
 
  /** The vector-Jacobian product. */
  std::vector<array> Axpby::vjp(
      const std::vector<array>& primals,
-      const array& cotan,
-      const std::vector<int>& argnums) {
+      const std::vector<array>& cotangents,
+      const std::vector<int>& argnums,
+      const std::vector<array>& /* unused */) {
    // Reverse mode diff
    std::vector<array> vjps;
    for (auto arg : argnums) {
      auto scale = arg == 0 ? alpha_ : beta_;
-      auto scale_arr = array(scale, cotan.dtype());
-      vjps.push_back(multiply(scale_arr, cotan, stream()));
+      auto scale_arr = array(scale, cotangents[0].dtype());
+      vjps.push_back(multiply(scale_arr, cotangents[0], stream()));
    }
    return vjps;
  }
 
-Finally, you need not have a transformation fully defined to start using your
-own :class:`Primitive`.
+Note, a transformation does not need to be fully defined to start using
+the :class:`Primitive`.
 
 .. code-block:: C++
 
   /** Vectorize primitive along given axis */
-  std::pair<array, int> Axpby::vmap(
+  std::pair<std::vector<array>, std::vector<int>> Axpby::vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) {
-    throw std::runtime_error("Axpby has no vmap implementation.");
+    throw std::runtime_error("[Axpby] vmap not implemented.");
  }
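+
+With :meth:`Axpby::vjp` in place, reverse-mode differentiation through the
+primitive should work end to end. For example, once the Python binding below
+is built (an illustrative sketch):
+
+.. code-block:: python
+
+   import mlx.core as mx
+   from mlx_sample_extensions import axpby
+
+   x = mx.ones((3,))
+   y = mx.ones((3,))
+
+   # d/dx sum(alpha * x + beta * y) = alpha
+   grad_fn = mx.grad(lambda x: axpby(x, y, 4.0, 2.0).sum())
+   print(grad_fn(x))  # array([4, 4, 4], dtype=float32)
+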
 
 Building and Binding
 --------------------
 
 Let's look at the overall directory structure first.
 
 | extensions
 | ├── axpby
 | │   ├── axpby.cpp
 | │   ├── axpby.h
 | │   └── axpby.metal
 | ├── mlx_sample_extensions
 | │   └── __init__.py
 | ├── bindings.cpp
 | ├── CMakeLists.txt
 | └── setup.py
 
 * ``extensions/axpby/`` defines the C++ extension library
-* ``extensions/mlx_sample_extensions`` sets out the structure for the
-  associated python package
-* ``extensions/bindings.cpp`` provides python bindings for our operation
-* ``extensions/CMakeLists.txt`` holds CMake rules to build the library and
-  python bindings
+* ``extensions/mlx_sample_extensions`` sets out the structure for the
+  associated Python package
+* ``extensions/bindings.cpp`` provides Python bindings for our operation
+* ``extensions/CMakeLists.txt`` holds CMake rules to build the library and
+  Python bindings
 * ``extensions/setup.py`` holds the ``setuptools`` rules to build and install
-  the python package
+  the Python package
 
 Binding to Python
 ^^^^^^^^^^^^^^^^^^
 
-We use PyBind11_ to build a Python API for the C++ library. Since bindings for
+We use nanobind_ to build a Python API for the C++ library. Since bindings for
 components such as :class:`mlx.core.array`, :class:`mlx.core.stream`, etc. are
-already provided, adding our :meth:`axpby` is simple!
+already provided, adding our :meth:`axpby` is simple.
 
 .. code-block:: C++
 
-  PYBIND11_MODULE(mlx_sample_extensions, m) {
-      m.doc() = "Sample C++ and metal extensions for MLX";
+  NB_MODULE(_ext, m) {
+      m.doc() = "Sample extension for MLX";
 
      m.def(
          "axpby",
          &axpby,
          "x"_a,
          "y"_a,
-          py::pos_only(),
          "alpha"_a,
          "beta"_a,
-          py::kw_only(),
-          "stream"_a = py::none(),
-          R"pbdoc(
+          nb::kw_only(),
+          "stream"_a = nb::none(),
+          R"(
            Scale and sum two vectors element-wise
            ``z = alpha * x + beta * y``
-
+
            Follows numpy style broadcasting between ``x`` and ``y``
            Inputs are upcasted to floats if needed
 
@@ -711,17 +669,17 @@ already provided, adding our :meth:`axpby` is simple!
 
            Returns:
                array: ``alpha * x + beta * y``
-          )pbdoc");
+          )");
  }
 
-Most of the complexity in the above example comes from additional bells and
+Most of the complexity in the above example comes from additional bells and
 whistles such as the literal names and doc-strings.
 
 .. warning::
 
-   :mod:`mlx.core` needs to be imported before importing
-   :mod:`mlx_sample_extensions` as defined by the pybind11 module above to
-   ensure that the casters for :mod:`mlx.core` components like
+   :mod:`mlx.core` must be imported before importing
+   :mod:`mlx_sample_extensions` as defined by the nanobind module above to
+   ensure that the casters for :mod:`mlx.core` components like
   :class:`mlx.core.array` are available.
 
 .. _Building with CMake:
 
 Building with CMake
 ^^^^^^^^^^^^^^^^^^^^
 
-Building the C++ extension library itself is simple, it only requires that you
-``find_package(MLX CONFIG)`` and then link it to your library.
+Building the C++ extension library only requires that you ``find_package(MLX
+CONFIG)`` and then link it to your library.
 
 .. code-block:: cmake
 
@@ -752,12 +710,12 @@ Building the C++ extension library itself is simple, it only requires that you
 
  # Link to mlx
  target_link_libraries(mlx_ext PUBLIC mlx)
 
-We also need to build the attached metal library. For convenience, we provide a
-:meth:`mlx_build_metallib` function that builds a ``.metallib`` target given
-sources, headers, destinations, etc. (defined in ``cmake/extension.cmake`` and
-automatically imported with MLX package).
+We also need to build the attached Metal library. For convenience, we provide a
+:meth:`mlx_build_metallib` function that builds a ``.metallib`` target given
+sources, headers, destinations, etc. (defined in ``cmake/extension.cmake`` and
+automatically imported with the MLX package).
 
-Here is what that looks like in practice!
+Here is what that looks like in practice:
 
 .. code-block:: cmake
 
@@ -779,27 +737,29 @@ Here is what that looks like in practice!
 
  endif()
 
-Finally, we build the Pybind11_ bindings
+Finally, we build the nanobind_ bindings:
 
 .. code-block:: cmake
 
-  pybind11_add_module(
-    mlx_sample_extensions
-    ${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
+  nanobind_add_module(
+    _ext
+    NB_STATIC STABLE_ABI LTO NOMINSIZE
+    NB_DOMAIN mlx
+    ${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
  )
-  target_link_libraries(mlx_sample_extensions PRIVATE mlx_ext)
+  target_link_libraries(_ext PRIVATE mlx_ext)
 
  if(BUILD_SHARED_LIBS)
-    target_link_options(mlx_sample_extensions PRIVATE -Wl,-rpath,@loader_path)
+    target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path)
  endif()
 
 Building with ``setuptools``
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
 Once we have set out the CMake build rules as described above, we can use the
-build utilities defined in :mod:`mlx.extension` for a simple build process.
+build utilities defined in :mod:`mlx.extension`:
 
 .. code-block:: python
 
    from mlx import extension
   from setuptools import setup
 
@@ -809,48 +769,50 @@ build utilities defined in :mod:`mlx.extension` for a simple build process.
       name="mlx_sample_extensions",
      version="0.0.0",
      description="Sample C++ and Metal extensions for MLX primitives.",
-      ext_modules=[extension.CMakeExtension("mlx_sample_extensions")],
+      ext_modules=[extension.CMakeExtension("mlx_sample_extensions._ext")],
      cmdclass={"build_ext": extension.CMakeBuild},
-      packages = ["mlx_sample_extensions"],
-      package_dir = {"": "mlx_sample_extensions"},
-      package_data = {"mlx_sample_extensions" : ["*.so", "*.dylib", "*.metallib"]},
+      packages=["mlx_sample_extensions"],
+      package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
+      extras_require={"dev": []},
      zip_safe=False,
-      python_requires=">=3.7",
+      python_requires=">=3.8",
  )
 
 .. note::
  We treat ``extensions/mlx_sample_extensions`` as the package directory
  even though it only contains a ``__init__.py`` to ensure the following:
-
-   * :mod:`mlx.core` is always imported before importing :mod:`mlx_sample_extensions`
-   * The C++ extension library and the metal library are co-located with the python
-     bindings and copied together if the package is installed
 
-You can build inplace for development using
+   * :mod:`mlx.core` must be imported before importing :mod:`_ext`
+   * The C++ extension library and the Metal library are co-located with the
+     Python bindings and copied together if the package is installed
+
+To build the package, first install the build dependencies with ``pip install
+-r requirements.txt``. You can then build in place for development using
 ``python setup.py build_ext -j8 --inplace`` (in ``extensions/``)
 
-This will result in a directory structure as follows:
+This results in the directory structure:
 
 | extensions
 | ├── mlx_sample_extensions
 | │   ├── __init__.py
 | │   ├── libmlx_ext.dylib # C++ extension library
 | │   ├── mlx_ext.metallib # Metal library
-| │   └── mlx_sample_extensions.cpython-3x-darwin.so # Python Binding
+| │   └── _ext.cpython-3x-darwin.so # Python Binding
 | ...
 
-When you try to install using the command ``python -m pip install .``
-(in ``extensions/``), the package will be installed with the same structure as
-``extensions/mlx_sample_extensions`` and the C++ and metal library will be
-copied along with the python binding since they are specified as ``package_data``.
+When you install using the command ``python -m pip install .`` (in
+``extensions/``), the package is installed with the same structure as
+``extensions/mlx_sample_extensions``, and the C++ and Metal libraries are
+copied along with the Python binding since they are specified as
+``package_data``.
 
 Usage
 -----
 
-After installing the extension as described above, you should be able to simply
-import the python package and play with it as you would any other MLX operation!
+After installing the extension as described above, you should be able to
+import the Python package and use it as you would any other MLX operation.
 
-Let's looks at a simple script and it's results!
+Let's look at a simple script and its results:
 
 .. code-block:: python
 
@@ -874,12 +836,12 @@ Output:
 
   c correctness: True
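+
+The script producing this output is elided above as unchanged context. It is
+essentially the following (a sketch; the array shapes are assumed):
+
+.. code-block:: python
+
+   import mlx.core as mx
+   from mlx_sample_extensions import axpby
+
+   a = mx.ones((3, 4))
+   b = mx.ones((3, 4))
+   c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
+
+   print(f"c correctness: {mx.all(c == 6.0).item()}")
+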
 
 Results
-^^^^^^^^^^^^^^^^
+^^^^^^^
 
-Let's run a quick benchmark and see how our new ``axpby`` operation compares
-with the naive :meth:`simple_axpby` we defined at first on the CPU.
+Let's run a quick benchmark and see how our new ``axpby`` operation compares
+with the naive :meth:`simple_axpby` we first defined on the CPU.
 
 .. code-block:: python
 
    import mlx.core as mx
   from mlx_sample_extensions import axpby
 
@@ -898,7 +860,7 @@ with the naive :meth:`simple_axpby` we defined at first on the CPU.
 
   alpha = 4.0
  beta = 2.0
 
-  mx.eval((x, y))
+  mx.eval(x, y)
 
  def bench(f):
      # Warm up
 
@@ -919,30 +881,23 @@ with the naive :meth:`simple_axpby` we defined at first on the CPU.
 
   print(f"Simple axpby: {simple_time:.3f} s | Custom axpby: {custom_time:.3f} s")
 
-Results:
-
-.. code-block::
-
-  Simple axpby: 0.114 s | Custom axpby: 0.109 s
-
-We see some modest improvements right away!
+The results are ``Simple axpby: 0.114 s | Custom axpby: 0.109 s``. We see
+modest improvements right away!
 
-This operation is now good to be used to build other operations, in
-:class:`mlx.nn.Module` calls, and also as a part of graph transformations like
-:meth:`grad`!
+You can now use this operation to build other operations, use it in
+:class:`mlx.nn.Module` layers, and as part of graph transformations like
+:meth:`grad`.
 
 Scripts
 -------
 
 .. admonition:: Download the code
 
-   The full example code is available in `mlx <code>`_.
-
-.. code: `https://github.com/ml-explore/mlx/tree/main/examples/extensions/`_
+   The full example code is available in
+   `mlx <https://github.com/ml-explore/mlx/tree/main/examples/extensions/>`_.
 
 .. _Accelerate: https://developer.apple.com/documentation/accelerate/blas?language=objc
 .. _Metal: https://developer.apple.com/documentation/metal?language=objc
 .. _Metal-cpp: https://developer.apple.com/metal/cpp/
 .. _`Metal Specification`: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
 .. _`Metal Example`: https://developer.apple.com/documentation/metal/performing_calculations_on_a_gpu?language=objc
-.. _PyBind11: https://pybind11.readthedocs.io/en/stable/
+.. _nanobind: https://nanobind.readthedocs.io/en/latest/
diff --git a/examples/extensions/CMakeLists.txt b/examples/extensions/CMakeLists.txt
index 79902cbd8..b58a51176 100644
--- a/examples/extensions/CMakeLists.txt
+++ b/examples/extensions/CMakeLists.txt
@@ -1,6 +1,6 @@
 cmake_minimum_required(VERSION 3.27)
 
-project(mlx_sample_extensions LANGUAGES CXX)
+project(_ext LANGUAGES CXX)
 
 # ----------------------------- Setup -----------------------------
 set(CMAKE_CXX_STANDARD 17)
@@ -11,8 +11,12 @@ option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON)
 
 # ----------------------------- Dependencies -----------------------------
 find_package(MLX CONFIG REQUIRED)
-find_package(Python COMPONENTS Interpreter Development)
-find_package(pybind11 CONFIG REQUIRED)
+find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
+execute_process(
+  COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
+  OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR)
+list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
+find_package(nanobind CONFIG REQUIRED)
 
 # ----------------------------- Extensions -----------------------------
 
@@ -38,7 +42,6 @@ target_link_libraries(mlx_ext PUBLIC mlx)
 
 # Build metallib
 if(MLX_BUILD_METAL)
-
   mlx_build_metallib(
    TARGET mlx_ext_metallib
    TITLE mlx_ext
@@ -54,13 +57,15 @@ if(MLX_BUILD_METAL)
 
 endif()
 
-# ----------------------------- Pybind -----------------------------
-pybind11_add_module(
-  mlx_sample_extensions
+# ----------------------------- Python Bindings -----------------------------
+nanobind_add_module(
+  _ext
+  NB_STATIC STABLE_ABI LTO NOMINSIZE
+  NB_DOMAIN mlx
   ${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
 )
-target_link_libraries(mlx_sample_extensions PRIVATE mlx_ext)
+target_link_libraries(_ext PRIVATE mlx_ext)
 
 if(BUILD_SHARED_LIBS)
-  target_link_options(mlx_sample_extensions PRIVATE -Wl,-rpath,@loader_path)
+  target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path)
 endif()
diff --git a/examples/extensions/README.md b/examples/extensions/README.md
new file mode 100644
index 000000000..17582bc0f
--- /dev/null
+++ b/examples/extensions/README.md
@@ -0,0 +1,18 @@
+
+## Build the extensions
+
+```
+pip install -e .
+```
+
+For faster builds during development, you can also pre-install the requirements:
+
+```
+pip install -r requirements.txt
+```
+
+And then run:
+
+```
+python setup.py build_ext -j8 --inplace
+```
diff --git a/examples/extensions/axpby/axpby.cpp b/examples/extensions/axpby/axpby.cpp
index 43b3aedc9..bfd308e7c 100644
--- a/examples/extensions/axpby/axpby.cpp
+++ b/examples/extensions/axpby/axpby.cpp
@@ -1,4 +1,4 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
 
 #include <cassert>
 #include <sstream>
@@ -43,7 +43,7 @@ array axpby(
   auto promoted_dtype = promote_types(x.dtype(), y.dtype());
 
  // Upcast to float32 for non-floating point inputs x and y
-  auto out_dtype = is_floating_point(promoted_dtype)
+  auto out_dtype = issubdtype(promoted_dtype, float32)
      ? promoted_dtype
      : promote_types(promoted_dtype, float32);
 
@@ -106,12 +106,12 @@ void axpby_impl(
 /** Fall back implementation for evaluation on CPU */
 void Axpby::eval(
     const std::vector<array>& inputs,
-    std::vector<array>& out_arr) {
-  auto out = out_arr[0];
+    std::vector<array>& outputs) {
   // Check the inputs (registered in the op while constructing the out array)
  assert(inputs.size() == 2);
  auto& x = inputs[0];
  auto& y = inputs[1];
+  auto& out = outputs[0];
 
  // Dispatch to the correct dtype
  if (out.dtype() == float32) {
@@ -150,11 +150,7 @@ void axpby_impl_accelerate(
   // The data in the output array is allocated to match the strides in y
  // such that x, y, and out are contiguous in the same mode and
  // no transposition is needed
-  out.set_data(
-      allocator::malloc_or_wait(y.data_size() * out.itemsize()),
-      y.data_size(),
-      y.strides(),
-      y.flags());
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
 
  // We then copy over the elements using the contiguous vector specialization
  copy_inplace(y, out, CopyType::Vector);
 
@@ -180,11 +176,11 @@ void axpby_impl_accelerate(
 /** Evaluate primitive on CPU using accelerate specializations */
 void Axpby::eval_cpu(
     const std::vector<array>& inputs,
-    std::vector<array>& outarr) {
-  auto out = outarr[0];
+    std::vector<array>& outputs) {
   assert(inputs.size() == 2);
  auto& x = inputs[0];
  auto& y = inputs[1];
+  auto& out = outputs[0];
 
  // Accelerate specialization for contiguous single precision float arrays
  if (out.dtype() == float32 &&
@@ -195,7 +191,7 @@ void Axpby::eval_cpu(
   }
 
  // Fall back to common backend if specializations are not available
-  eval(inputs, outarr);
+  eval(inputs, outputs);
 }
 
 #else // Accelerate not available
 
@@ -203,8 +199,8 @@ void Axpby::eval_cpu(
 /** Evaluate primitive on CPU falling back to common backend */
 void Axpby::eval_cpu(
     const std::vector<array>& inputs,
-    std::vector<array>& out) {
-  eval(inputs, out);
+    std::vector<array>& outputs) {
+  eval(inputs, outputs);
 }
 
 #endif
 
@@ -218,12 +214,12 @@ void Axpby::eval_cpu(
 /** Evaluate primitive on GPU */
 void Axpby::eval_gpu(
     const std::vector<array>& inputs,
-    std::vector<array>& outarr) {
+    std::vector<array>& outputs) {
   // Prepare inputs
-  auto out = outarr[0];
   assert(inputs.size() == 2);
  auto& x = inputs[0];
  auto& y = inputs[1];
+  auto& out = outputs[0];
 
  // Each primitive carries the stream it should execute on
  // and each stream carries its device identifiers
@@ -372,4 +368,4 @@ bool Axpby::is_equivalent(const Primitive& other) const {
   return alpha_ == r_other.alpha_ && beta_ == r_other.beta_;
 }
 
-} // namespace mlx::core
\ No newline at end of file
+} // namespace mlx::core
diff --git a/examples/extensions/axpby/axpby.h b/examples/extensions/axpby/axpby.h
index 649d9600a..1fb705d61 100644
--- a/examples/extensions/axpby/axpby.h
+++ b/examples/extensions/axpby/axpby.h
@@ -42,9 +42,9 @@ class Axpby : public Primitive {
    * To avoid unnecessary allocations, the evaluation function
   * is responsible for allocating space for the array.
   */
-  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& out)
+  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;
-  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& out)
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;
 
  /** The Jacobian-vector product. */
@@ -83,7 +83,7 @@ class Axpby : public Primitive {
   float beta_;
 
  /** Fall back implementation for evaluation on CPU */
-  void eval(const std::vector<array>& inputs, std::vector<array>& out);
+  void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
 };
 
-} // namespace mlx::core
\ No newline at end of file
+} // namespace mlx::core
diff --git a/examples/extensions/bindings.cpp b/examples/extensions/bindings.cpp
index d05e6b636..bd801b31e 100644
--- a/examples/extensions/bindings.cpp
+++ b/examples/extensions/bindings.cpp
@@ -1,31 +1,31 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
 
-#include <pybind11/pybind11.h>
-#include <pybind11/stl.h>
+#include <nanobind/nanobind.h>
+#include <nanobind/stl/variant.h>
 
 #include "axpby/axpby.h"
 
-namespace py = pybind11;
-using namespace py::literals;
+namespace nb = nanobind;
+using namespace nb::literals;
+
 using namespace mlx::core;
 
-PYBIND11_MODULE(mlx_sample_extensions, m) {
-  m.doc() = "Sample C++ and metal extensions for MLX";
+NB_MODULE(_ext, m) {
+  m.doc() = "Sample extension for MLX";
 
  m.def(
      "axpby",
      &axpby,
      "x"_a,
      "y"_a,
-      py::pos_only(),
      "alpha"_a,
      "beta"_a,
-      py::kw_only(),
-      "stream"_a = py::none(),
-      R"pbdoc(
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      R"(
        Scale and sum two vectors element-wise
        ``z = alpha * x + beta * y``
-
+
        Follows numpy style broadcasting between ``x`` and ``y``
        Inputs are upcasted to floats if needed
 
@@ -37,5 +37,5 @@ PYBIND11_MODULE(mlx_sample_extensions, m) {
 
       Returns:
          array: ``alpha * x + beta * y``
-      )pbdoc");
-}
\ No newline at end of file
+      )");
+}
diff --git a/examples/extensions/pyproject.toml b/examples/extensions/pyproject.toml
index 1c5302936..c71470da1 100644
--- a/examples/extensions/pyproject.toml
+++ b/examples/extensions/pyproject.toml
@@ -1,3 +1,8 @@
 [build-system]
-requires = ["setuptools>=42", "pybind11>=2.10", "cmake>=3.24", "mlx @ git+https://github.com/mlx-explore/mlx@main"]
-build-backend = "setuptools.build_meta"
\ No newline at end of file
+requires = [
+    "setuptools>=42",
+    "cmake>=3.24",
+    "mlx>=0.9.0",
+    "nanobind@git+https://github.com/wjakob/nanobind.git@4148debcf91f5ccab0c3b8d67b5c3cabd61f407f",
+]
+build-backend = "setuptools.build_meta"
diff --git a/examples/extensions/requirements.txt b/examples/extensions/requirements.txt
new file mode 100644
index 000000000..01a7d3864
--- /dev/null
+++ b/examples/extensions/requirements.txt
@@ -0,0 +1,4 @@
+setuptools>=42
+cmake>=3.24
+mlx>=0.9.0
+nanobind@git+https://github.com/wjakob/nanobind.git@4148debcf91f5ccab0c3b8d67b5c3cabd61f407f
diff --git a/examples/extensions/setup.py b/examples/extensions/setup.py
index d432f67f7..ab6a3c5f3 100644
--- a/examples/extensions/setup.py
+++ b/examples/extensions/setup.py
@@ -1,4 +1,4 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.
 
 from setuptools import setup
 
@@ -9,11 +9,11 @@ if __name__ == "__main__":
         name="mlx_sample_extensions",
        version="0.0.0",
        description="Sample C++ and Metal extensions for MLX primitives.",
-        ext_modules=[extension.CMakeExtension("mlx_sample_extensions")],
+        ext_modules=[extension.CMakeExtension("mlx_sample_extensions._ext")],
        cmdclass={"build_ext": extension.CMakeBuild},
        packages=["mlx_sample_extensions"],
-        package_dir={"": "."},
        package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
+        extras_require={"dev": []},
        zip_safe=False,
        python_requires=">=3.8",
    )