mirror of https://github.com/ml-explore/mlx.git
synced 2025-11-04 18:48:15 +08:00

Commit: rebase

Changed file: docs/build/html/_sources/dev/extensions.rst (vendored), 231 changed lines
@@ -22,12 +22,12 @@ You can do that in MLX directly:
 This function performs that operation while leaving the implementation and
 function transformations to MLX.

-However you may need to customize the underlying implementation, perhaps to
-make it faster or for custom differentiation. In this tutorial we will go
-through adding custom extensions. It will cover:
+However, you may want to customize the underlying implementation, perhaps to
+make it faster. In this tutorial we will go through adding custom extensions.
+It will cover:

 * The structure of the MLX library.
-* Implementing a CPU operation that redirects to Accelerate_ when appropriate.
+* Implementing a CPU operation.
 * Implementing a GPU operation using metal.
 * Adding the ``vjp`` and ``jvp`` function transformation.
 * Building a custom extension and binding it to python.
@@ -45,7 +45,7 @@ Operations
 Operations are the front-end functions that operate on arrays. They are defined
 in the C++ API (:ref:`cpp_ops`), and the Python API (:ref:`ops`) binds them.

-We would like an operation, :meth:`axpby` that takes in two arrays ``x`` and
+We would like an operation :meth:`axpby` that takes in two arrays, ``x`` and
 ``y``, and two scalars, ``alpha`` and ``beta``. This is how to define it in
 C++:

@@ -55,7 +55,7 @@ C++:
     *  Scale and sum two vectors element-wise
     *  z = alpha * x + beta * y
     *
-    *  Follow numpy style broadcasting between x and y
+    *  Use NumPy-style broadcasting between x and y
     *  Inputs are upcasted to floats if needed
     **/
     array axpby(
@@ -66,7 +66,7 @@ C++:
         StreamOrDevice s = {} // Stream on which to schedule the operation
     );

-The simplest way to this operation is in terms of existing operations:
+The simplest way to implement this is with existing operations:

 .. code-block:: C++

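The body of that code block is unchanged by this commit, so the diff elides it. For orientation, an implementation in terms of existing MLX operations could look roughly like the following sketch (the function name ``simple_axpby`` and the exact call pattern are illustrative, not quoted from the file):

.. code-block:: C++

    // A sketch of axpby built from existing MLX operations: scale each input
    // and add the results. multiply() and add() broadcast and promote types.
    array simple_axpby(
        const array& x,
        const array& y,
        const float alpha,
        const float beta,
        StreamOrDevice s = {} /* Stream on which to schedule the operation */) {
      return add(multiply(array(alpha), x, s), multiply(array(beta), y, s), s);
    }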
@@ -153,9 +153,6 @@ more concrete:
       private:
         float alpha_;
         float beta_;
-
-        /** Fall back implementation for evaluation on CPU */
-        void eval(const std::vector<array>& inputs, array& out);
     };

 The :class:`Axpby` class derives from the base :class:`Primitive` class. The
@@ -188,7 +185,7 @@ Let's reimplement our operation now in terms of our :class:`Axpby` primitive.
         auto promoted_dtype = promote_types(x.dtype(), y.dtype());

         // Upcast to float32 for non-floating point inputs x and y
-        auto out_dtype = is_floating_point(promoted_dtype)
+        auto out_dtype = issubdtype(promoted_dtype, float32)
             ? promoted_dtype
             : promote_types(promoted_dtype, float32);

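Further down in the same function (also elided from this diff), the reimplemented ``axpby`` hands the computation off to the :class:`Axpby` primitive by constructing the output array from it. A rough sketch of that step, where ``out_shape`` and ``broadcasted_inputs`` are illustrative placeholders rather than names taken from the file:

.. code-block:: C++

    // Construct the output array: its shape and dtype, the primitive that can
    // compute it, and the inputs it depends on. Evaluation stays lazy until
    // the graph containing this array is evaluated.
    return array(
        /* shape = */ out_shape,
        /* dtype = */ out_dtype,
        /* primitive = */ std::make_shared<Axpby>(to_stream(s), alpha, beta),
        /* inputs = */ broadcasted_inputs);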
@@ -234,49 +231,59 @@ the execution of the computation graph, and calls :meth:`Axpby::eval_cpu` or
 Implementing the CPU Back-end
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-Let's start by implementing a naive and generic version of
-:meth:`Axpby::eval_cpu`. We declared this as a private member function of
-:class:`Axpby` earlier called :meth:`Axpby::eval`.
+Let's start by implementing :meth:`Axpby::eval_cpu`.

-Our naive method will go over each element of the output array, find the
+The method will go over each element of the output array, find the
 corresponding input elements of ``x`` and ``y`` and perform the operation
 point-wise. This is captured in the templated function :meth:`axpby_impl`.

 .. code-block:: C++

-    template <typename T>
-    void axpby_impl(
-            const array& x,
-            const array& y,
-            array& out,
-            float alpha_,
-            float beta_) {
-        // We only allocate memory when we are ready to fill the output
-        // malloc_or_wait synchronously allocates available memory
-        // There may be a wait executed here if the allocation is requested
-        // under memory-pressured conditions
-        out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  template <typename T>
+  void axpby_impl(
+      const mx::array& x,
+      const mx::array& y,
+      mx::array& out,
+      float alpha_,
+      float beta_,
+      mx::Stream stream) {
+    // Allocate the output with `malloc_or_wait` which synchronously allocates
+    // memory, potentially waiting if the system is under memory pressure
+    out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));

-        // Collect input and output data pointers
-        const T* x_ptr = x.data<T>();
-        const T* y_ptr = y.data<T>();
-        T* out_ptr = out.data<T>();
+    // Get the CPU command encoder and register input and output arrays
+    auto& encoder = mx::cpu::get_command_encoder(stream);
+    encoder.set_input_array(x);
+    encoder.set_input_array(y);
+    encoder.set_output_array(out);

-        // Cast alpha and beta to the relevant types
-        T alpha = static_cast<T>(alpha_);
-        T beta = static_cast<T>(beta_);
+    // Launch the CPU kernel
+    encoder.dispatch([x_ptr = x.data<T>(),
+                      y_ptr = y.data<T>(),
+                      out_ptr = out.data<T>(),
+                      size = out.size(),
+                      shape = out.shape(),
+                      x_strides = x.strides(),
+                      y_strides = y.strides(),
+                      alpha_,
+                      beta_]() {

-        // Do the element-wise operation for each output
-        for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
-            // Map linear indices to offsets in x and y
-            auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
-            auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
+      // Cast alpha and beta to the relevant types
+      T alpha = static_cast<T>(alpha_);
+      T beta = static_cast<T>(beta_);

-            // We allocate the output to be contiguous and regularly strided
-            // (defaults to row major) and hence it doesn't need additional mapping
-            out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
-        }
-    }
+      // Do the element-wise operation for each output
+      for (size_t out_idx = 0; out_idx < size; out_idx++) {
+        // Map linear indices to offsets in x and y
+        auto x_offset = mx::elem_to_loc(out_idx, shape, x_strides);
+        auto y_offset = mx::elem_to_loc(out_idx, shape, y_strides);
+
+        // We allocate the output to be contiguous and regularly strided
+        // (defaults to row major) and hence it doesn't need additional mapping
+        out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
+      }
+    });
+  }

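The ``mx::elem_to_loc`` calls above are what make the kernel robust to broadcast or otherwise non-contiguous inputs: they map a row-major linear index over the output shape to a memory offset under each input's strides. A self-contained sketch of that mapping (illustrative only, not MLX's actual signature):

.. code-block:: C++

    #include <cstddef>
    #include <vector>

    // Map a row-major linear index over `shape` to a memory offset under
    // `strides`. A broadcast dimension simply carries a stride of 0.
    size_t elem_to_loc_sketch(
        size_t idx,
        const std::vector<int>& shape,
        const std::vector<size_t>& strides) {
      size_t loc = 0;
      for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
        loc += (idx % shape[i]) * strides[i];
        idx /= shape[i];
      }
      return loc;
    }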
 Our implementation should work for all incoming floating point arrays.
 Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and
@@ -284,112 +291,32 @@ Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and

 .. code-block:: C++

-    /** Fall back implementation for evaluation on CPU */
-    void Axpby::eval(
-      const std::vector<array>& inputs,
-      const std::vector<array>& outputs) {
-        auto& x = inputs[0];
-        auto& y = inputs[1];
-        auto& out = outputs[0];
-
-        // Dispatch to the correct dtype
-        if (out.dtype() == float32) {
-            return axpby_impl<float>(x, y, out, alpha_, beta_);
-        } else if (out.dtype() == float16) {
-            return axpby_impl<float16_t>(x, y, out, alpha_, beta_);
-        } else if (out.dtype() == bfloat16) {
-            return axpby_impl<bfloat16_t>(x, y, out, alpha_, beta_);
-        } else if (out.dtype() == complex64) {
-            return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
-        } else {
-            throw std::runtime_error(
-                "[Axpby] Only supports floating point types.");
-        }
-    }
-
-This is good as a fallback implementation. We can use the ``axpby`` routine
-provided by the Accelerate_ framework for a faster implementation in certain
-cases:
-
-#.  Accelerate does not provide implementations of ``axpby`` for half precision
-    floats. We can only use it for ``float32`` types.
-#.  Accelerate assumes the inputs ``x`` and ``y`` are contiguous and all
-    elements have fixed strides between them. We only direct to Accelerate
-    if both ``x`` and ``y`` are row contiguous or column contiguous.
-#.  Accelerate performs the routine ``Y = (alpha * X) + (beta * Y)`` in-place.
-    MLX expects to write the output to a new array. We must copy the elements
-    of ``y`` into the output and use that as an input to ``axpby``.
-
-Let's write an implementation that uses Accelerate in the right conditions.
-It allocates data for the output, copies ``y`` into it, and then calls the
-:func:`catlas_saxpby` from accelerate.
-
-.. code-block:: C++
-
-    template <typename T>
-    void axpby_impl_accelerate(
-            const array& x,
-            const array& y,
-            array& out,
-            float alpha_,
-            float beta_) {
-        // Accelerate library provides catlas_saxpby which does
-        // Y = (alpha * X) + (beta * Y) in place
-        // To use it, we first copy the data in y over to the output array
-        out.set_data(allocator::malloc_or_wait(out.nbytes()));
-
-        // We then copy over the elements using the contiguous vector specialization
-        copy_inplace(y, out, CopyType::Vector);
-
-        // Get x and y pointers for catlas_saxpby
-        const T* x_ptr = x.data<T>();
-        T* y_ptr = out.data<T>();
-
-        T alpha = static_cast<T>(alpha_);
-        T beta = static_cast<T>(beta_);
-
-        // Call the inplace accelerate operator
-        catlas_saxpby(
-            /* N = */ out.size(),
-            /* ALPHA = */ alpha,
-            /* X = */ x_ptr,
-            /* INCX = */ 1,
-            /* BETA = */ beta,
-            /* Y = */ y_ptr,
-            /* INCY = */ 1);
-    }
-
-For inputs that do not fit the criteria for accelerate, we fall back to
-:meth:`Axpby::eval`. With this in mind, let's finish our
-:meth:`Axpby::eval_cpu`.
-
-.. code-block:: C++
-
-    /** Evaluate primitive on CPU using accelerate specializations */
     void Axpby::eval_cpu(
-      const std::vector<array>& inputs,
-      const std::vector<array>& outputs) {
-        assert(inputs.size() == 2);
-        auto& x = inputs[0];
-        auto& y = inputs[1];
-        auto& out = outputs[0];
+        const std::vector<mx::array>& inputs,
+        std::vector<mx::array>& outputs) {
+      auto& x = inputs[0];
+      auto& y = inputs[1];
+      auto& out = outputs[0];

-        // Accelerate specialization for contiguous single precision float arrays
-        if (out.dtype() == float32 &&
-            ((x.flags().row_contiguous && y.flags().row_contiguous) ||
-            (x.flags().col_contiguous && y.flags().col_contiguous))) {
-            axpby_impl_accelerate<float>(x, y, out, alpha_, beta_);
-            return;
-        }
-
-        // Fall back to common back-end if specializations are not available
-        eval(inputs, outputs);
+      // Dispatch to the correct dtype
+      if (out.dtype() == mx::float32) {
+        return axpby_impl<float>(x, y, out, alpha_, beta_, stream());
+      } else if (out.dtype() == mx::float16) {
+        return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_, stream());
+      } else if (out.dtype() == mx::bfloat16) {
+        return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_, stream());
+      } else if (out.dtype() == mx::complex64) {
+        return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_, stream());
+      } else {
+        throw std::runtime_error(
+            "Axpby is only supported for floating point types.");
+      }
     }

 Just this much is enough to run the operation :meth:`axpby` on a CPU stream! If
 you do not plan on running the operation on the GPU or using transforms on
 computation graphs that contain :class:`Axpby`, you can stop implementing the
-primitive here and enjoy the speed-ups you get from the Accelerate library.
+primitive here.

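To sanity-check the CPU path at this point, a short usage sketch (the extension header path and namespace setup are assumptions about the sample project, not part of this file):

.. code-block:: C++

    #include "mlx/mlx.h"
    #include "axpby/axpby.h"  // assumed header for the extension shown above

    namespace mx = mlx::core;

    int main() {
      auto x = mx::ones({3, 4});
      auto y = mx::ones({3, 4});

      // With the CPU as the default device, this exercises Axpby::eval_cpu.
      auto z = axpby(x, y, /* alpha = */ 2.0f, /* beta = */ 0.5f);
      mx::eval({z});  // force evaluation of the lazy graph
      return 0;
    }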
 Implementing the GPU Back-end
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -824,7 +751,7 @@ Results
 ^^^^^^^

 Let's run a quick benchmark and see how our new ``axpby`` operation compares
-with the naive :meth:`simple_axpby` we first defined on the CPU.
+with the naive :meth:`simple_axpby` we first defined.

 .. code-block:: python

@@ -832,13 +759,11 @@ with the naive :meth:`simple_axpby` we first defined on the CPU.
     from mlx_sample_extensions import axpby
     import time

-    mx.set_default_device(mx.cpu)
-
     def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
         return alpha * x + beta * y

-    M = 256
-    N = 512
+    M = 4096
+    N = 4096

     x = mx.random.normal((M, N))
     y = mx.random.normal((M, N))
@@ -849,24 +774,24 @@ with the naive :meth:`simple_axpby` we first defined on the CPU.

     def bench(f):
         # Warm up
-        for i in range(100):
+        for i in range(5):
             z = f(x, y, alpha, beta)
             mx.eval(z)

         # Timed run
         s = time.time()
-        for i in range(5000):
+        for i in range(100):
             z = f(x, y, alpha, beta)
             mx.eval(z)
         e = time.time()
-        return e - s
+        return 1000 * (e - s) / 100

     simple_time = bench(simple_axpby)
     custom_time = bench(axpby)

-    print(f"Simple axpby: {simple_time:.3f} s | Custom axpby: {custom_time:.3f} s")
+    print(f"Simple axpby: {simple_time:.3f} ms | Custom axpby: {custom_time:.3f} ms")

-The results are ``Simple axpby: 0.114 s | Custom axpby: 0.109 s``. We see
+The results are ``Simple axpby: 1.559 ms | Custom axpby: 0.774 ms``. We see
 modest improvements right away!

 This operation is now good to be used to build other operations, in