mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	Simplifications for MLX C (#1396)
* simplifications for MLX C * use vectors instead of map * update examples
This commit is contained in:
		@@ -1,8 +1,8 @@
 | 
			
		||||
// Copyright © 2023-2024 Apple Inc.
 | 
			
		||||
 | 
			
		||||
#include <nanobind/nanobind.h>
 | 
			
		||||
#include <nanobind/stl/map.h>
 | 
			
		||||
#include <nanobind/stl/optional.h>
 | 
			
		||||
#include <nanobind/stl/pair.h>
 | 
			
		||||
#include <nanobind/stl/string.h>
 | 
			
		||||
#include <nanobind/stl/tuple.h>
 | 
			
		||||
#include <nanobind/stl/variant.h>
 | 
			
		||||
@@ -193,39 +193,130 @@ void init_fast(nb::module_& parent_module) {
 | 
			
		||||
          array: The quantized version of ``w``
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
 | 
			
		||||
  nb::class_<fast::MetalKernel>(
 | 
			
		||||
      m,
 | 
			
		||||
  m.def(
 | 
			
		||||
      "metal_kernel",
 | 
			
		||||
      [](const std::string& name,
 | 
			
		||||
         const std::vector<std::string>& input_names,
 | 
			
		||||
         const std::vector<std::string>& output_names,
 | 
			
		||||
         const std::string& source,
 | 
			
		||||
         const std::string& header,
 | 
			
		||||
         bool ensure_row_contiguous,
 | 
			
		||||
         bool atomic_outputs) {
 | 
			
		||||
        auto kernel = fast::metal_kernel(
 | 
			
		||||
            name,
 | 
			
		||||
            input_names,
 | 
			
		||||
            output_names,
 | 
			
		||||
            source,
 | 
			
		||||
            header,
 | 
			
		||||
            ensure_row_contiguous,
 | 
			
		||||
            atomic_outputs);
 | 
			
		||||
        return nb::cpp_function(
 | 
			
		||||
            [kernel = std::move(kernel)](
 | 
			
		||||
                const std::vector<ScalarOrArray>& inputs_,
 | 
			
		||||
                const std::vector<std::vector<int>>& output_shapes,
 | 
			
		||||
                const std::vector<Dtype>& output_dtypes,
 | 
			
		||||
                std::tuple<int, int, int> grid,
 | 
			
		||||
                std::tuple<int, int, int> threadgroup,
 | 
			
		||||
                const std::optional<
 | 
			
		||||
                    std::vector<std::pair<std::string, nb::object>>>&
 | 
			
		||||
                    template_args_ = std::nullopt,
 | 
			
		||||
                std::optional<float> init_value = std::nullopt,
 | 
			
		||||
                bool verbose = false,
 | 
			
		||||
                StreamOrDevice s = {}) {
 | 
			
		||||
              std::vector<array> inputs;
 | 
			
		||||
              for (const auto& value : inputs_) {
 | 
			
		||||
                inputs.push_back(to_array(value, std::nullopt));
 | 
			
		||||
              }
 | 
			
		||||
              std::vector<std::pair<std::string, fast::TemplateArg>>
 | 
			
		||||
                  template_args;
 | 
			
		||||
              if (template_args_) {
 | 
			
		||||
                for (const auto& [name, value] : template_args_.value()) {
 | 
			
		||||
                  // Handle bool, int and dtype template args
 | 
			
		||||
                  if (nb::isinstance<bool>(value)) {
 | 
			
		||||
                    bool bool_val = nb::cast<bool>(value);
 | 
			
		||||
                    template_args.emplace_back(name, bool_val);
 | 
			
		||||
                  } else if (nb::isinstance<int>(value)) {
 | 
			
		||||
                    int int_val = nb::cast<int>(value);
 | 
			
		||||
                    template_args.emplace_back(name, int_val);
 | 
			
		||||
                  } else if (nb::isinstance<Dtype>(value)) {
 | 
			
		||||
                    Dtype dtype = nb::cast<Dtype>(value);
 | 
			
		||||
                    template_args.emplace_back(name, dtype);
 | 
			
		||||
                  } else {
 | 
			
		||||
                    throw std::invalid_argument(
 | 
			
		||||
                        "[[metal_kernel]] Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`.");
 | 
			
		||||
                  }
 | 
			
		||||
                }
 | 
			
		||||
              }
 | 
			
		||||
              return kernel(
 | 
			
		||||
                  inputs,
 | 
			
		||||
                  output_shapes,
 | 
			
		||||
                  output_dtypes,
 | 
			
		||||
                  grid,
 | 
			
		||||
                  threadgroup,
 | 
			
		||||
                  template_args,
 | 
			
		||||
                  init_value,
 | 
			
		||||
                  verbose,
 | 
			
		||||
                  s);
 | 
			
		||||
            },
 | 
			
		||||
            nb::kw_only(),
 | 
			
		||||
            "inputs"_a,
 | 
			
		||||
            "output_shapes"_a,
 | 
			
		||||
            "output_dtypes"_a,
 | 
			
		||||
            "grid"_a,
 | 
			
		||||
            "threadgroup"_a,
 | 
			
		||||
            "template"_a = nb::none(),
 | 
			
		||||
            "init_value"_a = nb::none(),
 | 
			
		||||
            "verbose"_a = false,
 | 
			
		||||
            "stream"_a = nb::none(),
 | 
			
		||||
            nb::sig(
 | 
			
		||||
                "def __call__(self, *, inputs: List[Union[scalar, array]], output_shapes: List[Sequence[int]], output_dtypes: List[Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[List[Tuple[str, Union[bool, int, Dtype]]]] = None, init_value: Optional[float] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)"),
 | 
			
		||||
            R"pbdoc(
 | 
			
		||||
           Run the kernel.
 | 
			
		||||
 | 
			
		||||
           Args:
 | 
			
		||||
             inputs (List[array]): The inputs passed to the Metal kernel.
 | 
			
		||||
             output_shapes (List[Sequence[int]]): The list of shapes for each output in ``output_names``.
 | 
			
		||||
             output_dtypes (List[Dtype]): The list of data types for each output in ``output_names``.
 | 
			
		||||
             grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
 | 
			
		||||
             threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
 | 
			
		||||
             template (List[Tuple[str, Union[bool, int, Dtype]]], optional): Template arguments.
 | 
			
		||||
                 These will be added as template arguments to the kernel definition. Default: ``None``.
 | 
			
		||||
             init_value (float, optional): Optional value to use to initialize all of the output arrays.
 | 
			
		||||
                 By default, output arrays are uninitialized. Default: ``None``.
 | 
			
		||||
             verbose (bool, optional): Whether to print the full generated source code of the kernel
 | 
			
		||||
                 when it is run. Default: ``False``.
 | 
			
		||||
             stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.
 | 
			
		||||
 | 
			
		||||
           Returns:
 | 
			
		||||
             List[array]: The list of output arrays.
 | 
			
		||||
        )pbdoc");
 | 
			
		||||
      },
 | 
			
		||||
      "name"_a,
 | 
			
		||||
      "input_names"_a,
 | 
			
		||||
      "output_names"_a,
 | 
			
		||||
      "source"_a,
 | 
			
		||||
      "header"_a = "",
 | 
			
		||||
      "ensure_row_contiguous"_a = true,
 | 
			
		||||
      "atomic_outputs"_a = false,
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
      A jit-compiled custom Metal kernel defined from a source string.
 | 
			
		||||
      )pbdoc")
 | 
			
		||||
      .def(
 | 
			
		||||
          nb::init<
 | 
			
		||||
              const std::string&,
 | 
			
		||||
              const std::string&,
 | 
			
		||||
              const std::string&,
 | 
			
		||||
              bool,
 | 
			
		||||
              bool>(),
 | 
			
		||||
          "name"_a,
 | 
			
		||||
          "source"_a,
 | 
			
		||||
          "header"_a = "",
 | 
			
		||||
          "ensure_row_contiguous"_a = true,
 | 
			
		||||
          "atomic_outputs"_a = false,
 | 
			
		||||
          R"pbdoc(
 | 
			
		||||
      Initialize a metal_kernel.
 | 
			
		||||
 | 
			
		||||
      Args:
 | 
			
		||||
        name (str): Name for the kernel.
 | 
			
		||||
        source (str): Source code. This is the body of a function in Metal,
 | 
			
		||||
            the function signature will be generated for you. The names of the inputs/outputs
 | 
			
		||||
            are determined by the ``inputs`` and ``output_shapes``/``output_dtypes``
 | 
			
		||||
            used when the kernel is called.
 | 
			
		||||
        header (str): Header source code to include before the main function.
 | 
			
		||||
            Useful for helper functions or includes that should live outside of the main function body.
 | 
			
		||||
        ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
 | 
			
		||||
            before the kernel runs. Default: ``True``.
 | 
			
		||||
        atomic_outputs (bool): Whether to use atomic outputs in the function signature
 | 
			
		||||
            e.g. ``device atomic<float>``. Default: ``False``.
 | 
			
		||||
       name (str): Name for the kernel.
 | 
			
		||||
       input_names (List[str]): The parameter names of the inputs in the
 | 
			
		||||
          function signature.
 | 
			
		||||
       output_names (List[str]): The parameter names of the outputs in the
 | 
			
		||||
           function signature.
 | 
			
		||||
       source (str): Source code. This is the body of a function in Metal,
 | 
			
		||||
           the function signature will be automatically generated.
 | 
			
		||||
       header (str): Header source code to include before the main function.
 | 
			
		||||
           Useful for helper functions or includes that should live outside of
 | 
			
		||||
            the main function body.
 | 
			
		||||
       ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
 | 
			
		||||
           before the kernel runs. Default: ``True``.
 | 
			
		||||
       atomic_outputs (bool): Whether to use atomic outputs in the function signature
 | 
			
		||||
           e.g. ``device atomic<float>``. Default: ``False``.
 | 
			
		||||
 | 
			
		||||
      Returns:
 | 
			
		||||
        Callable ``metal_kernel``.
 | 
			
		||||
 | 
			
		||||
@@ -242,103 +333,23 @@ void init_fast(nb::module_& parent_module) {
 | 
			
		||||
 | 
			
		||||
              kernel = mx.fast.metal_kernel(
 | 
			
		||||
                  name="myexp",
 | 
			
		||||
                  input_names=["inp"],
 | 
			
		||||
                  output_names=["out"],
 | 
			
		||||
                  source=source
 | 
			
		||||
              )
 | 
			
		||||
              outputs = kernel(
 | 
			
		||||
                  inputs={"inp": a},
 | 
			
		||||
                  template={"T": mx.float32},
 | 
			
		||||
                  inputs=[a],
 | 
			
		||||
                  template=[("T", mx.float32)],
 | 
			
		||||
                  grid=(a.size, 1, 1),
 | 
			
		||||
                  threadgroup=(256, 1, 1),
 | 
			
		||||
                  output_shapes={"out": a.shape},
 | 
			
		||||
                  output_dtypes={"out": a.dtype},
 | 
			
		||||
                  output_shapes=[a.shape],
 | 
			
		||||
                  output_dtypes=[a.dtype],
 | 
			
		||||
                  verbose=True,
 | 
			
		||||
              )
 | 
			
		||||
              return outputs["out"]
 | 
			
		||||
              return outputs[0]
 | 
			
		||||
 | 
			
		||||
          a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
 | 
			
		||||
          b = exp_elementwise(a)
 | 
			
		||||
          assert mx.allclose(b, mx.exp(a))
 | 
			
		||||
      )pbdoc")
 | 
			
		||||
      .def(
 | 
			
		||||
          "__call__",
 | 
			
		||||
          [](fast::MetalKernel& kernel,
 | 
			
		||||
             std::map<std::string, ScalarOrArray>& inputs_,
 | 
			
		||||
             std::map<std::string, std::vector<int>>& output_shapes,
 | 
			
		||||
             std::map<std::string, Dtype>& output_dtypes,
 | 
			
		||||
             std::tuple<int, int, int> grid,
 | 
			
		||||
             std::tuple<int, int, int> threadgroup,
 | 
			
		||||
             std::optional<std::map<std::string, nb::handle>> template_args_,
 | 
			
		||||
             std::optional<float> init_value,
 | 
			
		||||
             bool verbose,
 | 
			
		||||
             StreamOrDevice s) {
 | 
			
		||||
            std::map<std::string, array> inputs;
 | 
			
		||||
            for (const auto& [name, value] : inputs_) {
 | 
			
		||||
              auto arr = to_array(value, std::nullopt);
 | 
			
		||||
              inputs.insert({name, arr});
 | 
			
		||||
            }
 | 
			
		||||
            std::map<std::string, fast::TemplateArg> template_args;
 | 
			
		||||
            if (template_args_) {
 | 
			
		||||
              for (const auto& [name, value] : template_args_.value()) {
 | 
			
		||||
                // Handle bool, int and dtype template args
 | 
			
		||||
                if (nb::isinstance<bool>(value)) {
 | 
			
		||||
                  bool bool_val = nb::cast<bool>(value);
 | 
			
		||||
                  template_args.insert({name, bool_val});
 | 
			
		||||
                } else if (nb::isinstance<int>(value)) {
 | 
			
		||||
                  int int_val = nb::cast<int>(value);
 | 
			
		||||
                  template_args.insert({name, int_val});
 | 
			
		||||
                } else if (nb::isinstance<Dtype>(value)) {
 | 
			
		||||
                  Dtype dtype = nb::cast<Dtype>(value);
 | 
			
		||||
                  template_args.insert({name, dtype});
 | 
			
		||||
                } else {
 | 
			
		||||
                  throw std::invalid_argument(
 | 
			
		||||
                      "[[metal_kernel]] Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`.");
 | 
			
		||||
                }
 | 
			
		||||
              }
 | 
			
		||||
            }
 | 
			
		||||
            return kernel(
 | 
			
		||||
                inputs,
 | 
			
		||||
                output_shapes,
 | 
			
		||||
                output_dtypes,
 | 
			
		||||
                grid,
 | 
			
		||||
                threadgroup,
 | 
			
		||||
                template_args,
 | 
			
		||||
                init_value,
 | 
			
		||||
                verbose,
 | 
			
		||||
                s);
 | 
			
		||||
          },
 | 
			
		||||
          nb::kw_only(),
 | 
			
		||||
          "inputs"_a,
 | 
			
		||||
          "output_shapes"_a,
 | 
			
		||||
          "output_dtypes"_a,
 | 
			
		||||
          "grid"_a,
 | 
			
		||||
          "threadgroup"_a,
 | 
			
		||||
          "template"_a = nb::none(),
 | 
			
		||||
          "init_value"_a = nb::none(),
 | 
			
		||||
          "verbose"_a = false,
 | 
			
		||||
          "stream"_a = nb::none(),
 | 
			
		||||
          nb::sig(
 | 
			
		||||
              "def __call__(self, *, inputs: Mapping[str, Union[scalar, array]], output_shapes: Mapping[str, Sequence[int]], output_dtypes: Mapping[str, Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[Mapping[str, Union[bool, int, Dtype]]] = None, init_value: Optional[float] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)"),
 | 
			
		||||
          R"pbdoc(
 | 
			
		||||
            Run the kernel.
 | 
			
		||||
 | 
			
		||||
            Args:
 | 
			
		||||
              inputs (Mapping[str, array]): Inputs. These will be added to the function signature and passed to the Metal kernel.
 | 
			
		||||
                  The keys will be the names of the arguments to the kernel.
 | 
			
		||||
              output_shapes (Mapping[str, Sequence[int]]): Output shapes. A dict mapping
 | 
			
		||||
                  output variable names to shapes. These will be added to the function signature.
 | 
			
		||||
              output_dtypes (Mapping[str, Dtype]): Output dtypes. A dict mapping output variable
 | 
			
		||||
                  names to dtypes. Must have the same keys as ``output_shapes``.
 | 
			
		||||
              grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
 | 
			
		||||
              threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
 | 
			
		||||
              template (Mapping[str, Union[bool, int, Dtype]], optional): Template arguments.
 | 
			
		||||
                  These will be added as template arguments to the kernel definition. Default: ``None``.
 | 
			
		||||
              init_value (float, optional): Optional value to use to initialize all of the output arrays.
 | 
			
		||||
                  By default, output arrays are uninitialized. Default: ``None``.
 | 
			
		||||
              verbose (bool, optional): Whether to print the full generated source code of the kernel
 | 
			
		||||
                  when it is run. Default: ``False``.
 | 
			
		||||
              stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.
 | 
			
		||||
 | 
			
		||||
            Returns:
 | 
			
		||||
              dict[str, array]: Dictionary of output arrays based on ``output_shapes``/``output_dtypes``.
 | 
			
		||||
            )pbdoc");
 | 
			
		||||
     )pbdoc");
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user