mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	Custom Metal Kernels from Python (#1325)
* start * simple kernels working * restructure * inverse example working * docs + fixes * missing file * fix imports * address comments * add docs + fix test * Review comments + refactor to a single function * update docs * remove hashing * fix contig bug in test * back to a class * trailing whitespace * fix tests * match c++ and python apis * add link + make args kw_only
This commit is contained in:
		@@ -1,9 +1,14 @@
 | 
			
		||||
// Copyright © 2023-2024 Apple Inc.
 | 
			
		||||
 | 
			
		||||
#include <nanobind/nanobind.h>
 | 
			
		||||
#include <nanobind/stl/map.h>
 | 
			
		||||
#include <nanobind/stl/optional.h>
 | 
			
		||||
#include <nanobind/stl/string.h>
 | 
			
		||||
#include <nanobind/stl/tuple.h>
 | 
			
		||||
#include <nanobind/stl/variant.h>
 | 
			
		||||
#include <nanobind/stl/vector.h>
 | 
			
		||||
 | 
			
		||||
#include "python/src/utils.h"
 | 
			
		||||
 | 
			
		||||
#include "mlx/fast.h"
 | 
			
		||||
#include "mlx/ops.h"
 | 
			
		||||
@@ -186,4 +191,136 @@ void init_fast(nb::module_& parent_module) {
 | 
			
		||||
        Returns:
 | 
			
		||||
          array: The quantized version of ``w``
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
 | 
			
		||||
  nb::class_<fast::MetalKernel>(
 | 
			
		||||
      m,
 | 
			
		||||
      "metal_kernel",
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
      A jit-compiled custom Metal kernel defined from a source string.
 | 
			
		||||
      )pbdoc")
 | 
			
		||||
      .def(
 | 
			
		||||
          nb::init<const std::string&, const std::string&, bool>(),
 | 
			
		||||
          "name"_a,
 | 
			
		||||
          "source"_a,
 | 
			
		||||
          "ensure_row_contiguous"_a = true,
 | 
			
		||||
          R"pbdoc(
 | 
			
		||||
      Initialize a metal_kernel.
 | 
			
		||||
 | 
			
		||||
      Args:
 | 
			
		||||
        name (str): Name for the kernel.
 | 
			
		||||
        source (str): Source code. This is the body of a function in Metal,
 | 
			
		||||
            the function signature will be generated for you. The names of the inputs/outputs
 | 
			
		||||
            are determined by the ``inputs`` and ``output_shapes``/``output_dtypes``
 | 
			
		||||
            used when the kernel is called.
 | 
			
		||||
        ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
 | 
			
		||||
            before the kernel runs. Default: ``True``.
 | 
			
		||||
      Returns:
 | 
			
		||||
        Callable ``metal_kernel``.
 | 
			
		||||
 | 
			
		||||
      .. code-block:: python
 | 
			
		||||
 | 
			
		||||
        def exp_elementwise(a: mx.array):
 | 
			
		||||
            source = """
 | 
			
		||||
                uint elem = thread_position_in_grid.x;
 | 
			
		||||
                T tmp = inp[elem];
 | 
			
		||||
                out[elem] = metal::exp(tmp);
 | 
			
		||||
            """
 | 
			
		||||
 | 
			
		||||
            kernel = mx.fast.metal_kernel(
 | 
			
		||||
                name="myexp",
 | 
			
		||||
                source=source
 | 
			
		||||
            )
 | 
			
		||||
            outputs = kernel(
 | 
			
		||||
                inputs={"inp": a},
 | 
			
		||||
                template={"T": mx.float32},
 | 
			
		||||
                grid=(a.size, 1, 1),
 | 
			
		||||
                threadgroup=(256, 1, 1),
 | 
			
		||||
                output_shapes={"out": a.shape},
 | 
			
		||||
                output_dtypes={"out": a.dtype},
 | 
			
		||||
                verbose=True,
 | 
			
		||||
            )
 | 
			
		||||
            return outputs["out"]
 | 
			
		||||
 | 
			
		||||
        a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
 | 
			
		||||
        b = exp_elementwise(a)
 | 
			
		||||
        assert mx.allclose(b, mx.exp(a))
 | 
			
		||||
 | 
			
		||||
      )pbdoc")
 | 
			
		||||
      .def(
 | 
			
		||||
          "__call__",
 | 
			
		||||
          [](fast::MetalKernel& kernel,
 | 
			
		||||
             std::map<std::string, ScalarOrArray>& inputs_,
 | 
			
		||||
             std::map<std::string, std::vector<int>>& output_shapes,
 | 
			
		||||
             std::map<std::string, Dtype>& output_dtypes,
 | 
			
		||||
             std::tuple<int, int, int> grid,
 | 
			
		||||
             std::tuple<int, int, int> threadgroup,
 | 
			
		||||
             std::optional<std::map<std::string, nb::handle>> template_args_,
 | 
			
		||||
             bool verbose,
 | 
			
		||||
             StreamOrDevice s) {
 | 
			
		||||
            std::map<std::string, array> inputs;
 | 
			
		||||
            for (const auto& [name, value] : inputs_) {
 | 
			
		||||
              auto arr = to_array(value, std::nullopt);
 | 
			
		||||
              inputs.insert({name, arr});
 | 
			
		||||
            }
 | 
			
		||||
            std::map<std::string, fast::TemplateArg> template_args;
 | 
			
		||||
            if (template_args_) {
 | 
			
		||||
              for (const auto& [name, value] : template_args_.value()) {
 | 
			
		||||
                // Handle bool, int and dtype template args
 | 
			
		||||
                if (nb::isinstance<bool>(value)) {
 | 
			
		||||
                  bool bool_val = nb::cast<bool>(value);
 | 
			
		||||
                  template_args.insert({name, bool_val});
 | 
			
		||||
                } else if (nb::isinstance<int>(value)) {
 | 
			
		||||
                  int int_val = nb::cast<int>(value);
 | 
			
		||||
                  template_args.insert({name, int_val});
 | 
			
		||||
                } else if (nb::isinstance<Dtype>(value)) {
 | 
			
		||||
                  Dtype dtype = nb::cast<Dtype>(value);
 | 
			
		||||
                  template_args.insert({name, dtype});
 | 
			
		||||
                } else {
 | 
			
		||||
                  throw std::invalid_argument(
 | 
			
		||||
                      "[[metal_kernel]] Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`.");
 | 
			
		||||
                }
 | 
			
		||||
              }
 | 
			
		||||
            }
 | 
			
		||||
            return kernel(
 | 
			
		||||
                inputs,
 | 
			
		||||
                output_shapes,
 | 
			
		||||
                output_dtypes,
 | 
			
		||||
                grid,
 | 
			
		||||
                threadgroup,
 | 
			
		||||
                template_args,
 | 
			
		||||
                verbose,
 | 
			
		||||
                s);
 | 
			
		||||
          },
 | 
			
		||||
          nb::kw_only(),
 | 
			
		||||
          "inputs"_a,
 | 
			
		||||
          "output_shapes"_a,
 | 
			
		||||
          "output_dtypes"_a,
 | 
			
		||||
          "grid"_a,
 | 
			
		||||
          "threadgroup"_a,
 | 
			
		||||
          "template"_a = nb::none(),
 | 
			
		||||
          "verbose"_a = false,
 | 
			
		||||
          "stream"_a = nb::none(),
 | 
			
		||||
          nb::sig(
 | 
			
		||||
              "def __call__(self, *, inputs: Mapping[str, Union[scalar, array]], output_shapes: Mapping[str, Sequence[int]], output_dtypes: Mapping[str, Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[Mapping[str, Union[bool, int, Dtype]]] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)"),
 | 
			
		||||
          R"pbdoc(
 | 
			
		||||
            Run the kernel.
 | 
			
		||||
 | 
			
		||||
            Args:
 | 
			
		||||
              inputs (Mapping[str, array]): Inputs. These will be added to the function signature and passed to the Metal kernel.
 | 
			
		||||
                  The keys will be the names of the arguments to the kernel.
 | 
			
		||||
              output_shapes (Mapping[str, Sequence[int]]): Output shapes. A dict mapping
 | 
			
		||||
                  output variable names to shapes. These will be added to the function signature.
 | 
			
		||||
              output_dtypes (Mapping[str, Dtype]): Output dtypes. A dict mapping output variable
 | 
			
		||||
                  names to dtypes. Must have the same keys as ``output_shapes``.
 | 
			
		||||
              grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
 | 
			
		||||
              threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
 | 
			
		||||
              template (Mapping[str, Union[bool, int, Dtype]], optional): Template arguments.
 | 
			
		||||
                  These will be added as template arguments to the kernel definition.
 | 
			
		||||
              verbose (bool, optional): Whether to print the full generated source code of the kernel
 | 
			
		||||
                  when it is run.
 | 
			
		||||
              stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.
 | 
			
		||||
 | 
			
		||||
            Returns:
 | 
			
		||||
              dict[str, array]: Dictionary of output arrays based on ``output_shapes``/``output_dtypes``.
 | 
			
		||||
            )pbdoc");
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -325,9 +325,9 @@ void init_linalg(nb::module_& parent_module) {
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def cholesky_inv(L: array, upper: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        Compute the inverse of a real symmetric positive semi-definite matrix using it's Cholesky decomposition L.
 | 
			
		||||
        Compute the inverse of a real symmetric positive semi-definite matrix using it's Cholesky decomposition.
 | 
			
		||||
 | 
			
		||||
        Let A be a real symmetric positive semi-definite matrix and L its Cholesky definition such that:
 | 
			
		||||
        Let :math:`\mathbf{A}` be a real symmetric positive semi-definite matrix and :math:`\mathbf{L}` its Cholesky decomposition such that:
 | 
			
		||||
 | 
			
		||||
        .. math::
 | 
			
		||||
 | 
			
		||||
@@ -339,7 +339,7 @@ void init_linalg(nb::module_& parent_module) {
 | 
			
		||||
 | 
			
		||||
        This function supports arrays with at least 2 dimensions. When the input
 | 
			
		||||
        has more than two dimensions, the Cholesky inverse is computed for each matrix
 | 
			
		||||
        in the last two dimensions of ``L``.
 | 
			
		||||
        in the last two dimensions of :math:`\mathbf{L}`.
 | 
			
		||||
 | 
			
		||||
        If the input matrix is not a triangular matrix behaviour is undefined.
 | 
			
		||||
 | 
			
		||||
@@ -351,6 +351,6 @@ void init_linalg(nb::module_& parent_module) {
 | 
			
		||||
              in which case the default stream of the default device is used.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
          array: :math:`A^{-1}` where :math:`\mathbf{A} = \mathbf{L}\mathbf{L}^T`.
 | 
			
		||||
          array: :math:`\mathbf{A^{-1}}` where :math:`\mathbf{A} = \mathbf{L}\mathbf{L}^T`.
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user