diff --git a/docs/src/dev/custom_metal_kernels.rst b/docs/src/dev/custom_metal_kernels.rst
index 04cede771..c4c1b0aff 100644
--- a/docs/src/dev/custom_metal_kernels.rst
+++ b/docs/src/dev/custom_metal_kernels.rst
@@ -19,17 +19,19 @@ Let's write a custom kernel that computes ``exp`` elementwise:
 
       kernel = mx.fast.metal_kernel(
           name="myexp",
+          input_names=["inp"],
+          output_names=["out"],
           source=source,
       )
       outputs = kernel(
-          inputs={"inp": a},
-          template={"T": mx.float32},
+          inputs=[a],
+          template=[("T", mx.float32)],
           grid=(a.size, 1, 1),
           threadgroup=(256, 1, 1),
-          output_shapes={"out": a.shape},
-          output_dtypes={"out": a.dtype},
+          output_shapes=[a.shape],
+          output_dtypes=[a.dtype],
       )
-      return outputs["out"]
+      return outputs[0]
 
   a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
   b = exp_elementwise(a)
@@ -40,16 +42,16 @@ Let's write a custom kernel that computes ``exp`` elementwise:
 The full function signature will be generated using:
 
-* The keys and shapes/dtypes of ``inputs``
+* The shapes/dtypes of ``inputs``
 
   In the above, ``a`` is an ``mx.array`` of type ``mx.float16`` and we pass it
   with the key ``inp`` so we will add ``const device float16_t* inp`` to the signature.
   ``inp_shape``, ``inp_strides`` and ``inp_ndim`` are also added for convenience if they
   are present in ``source``.
 
-* The keys and values of ``output_shapes`` and ``output_dtypes``
+* The list of ``output_dtypes``
 
   In the above, ``out`` is an ``mx.array`` of type ``mx.float16``
   so we add ``device float16_t* out``.
 
 * Template parameters passed using ``template``
 
-  In the above, ``template={"T": mx.float32}`` adds a template of ``template <typename T>`` to the function
+  In the above, ``template=[("T", mx.float32)]`` adds a template of ``template <typename T>`` to the function
   and instantiates the template with ``custom_kernel_myexp_float``.
   Template parameters can be ``mx.core.Dtype``, ``int`` or ``bool``.
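For reference, the generated signature and template instantiation can be inspected by passing ``verbose=True`` when the kernel is called (the same flag used in the docstring example later in this diff). A minimal sketch, assuming the ``myexp`` kernel above and an available Metal device:

```python
import mlx.core as mx

source = """
    uint elem = thread_position_in_grid.x;
    T tmp = inp[elem];
    out[elem] = metal::exp(tmp);
"""

kernel = mx.fast.metal_kernel(
    name="myexp",
    input_names=["inp"],
    output_names=["out"],
    source=source,
)
a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
outputs = kernel(
    inputs=[a],
    template=[("T", mx.float32)],
    grid=(a.size, 1, 1),
    threadgroup=(256, 1, 1),
    output_shapes=[a.shape],
    output_dtypes=[a.dtype],
    verbose=True,  # print the generated Metal source, including the full signature
)
print(outputs[0])
```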
* Metal attributes used in ``source`` such as ``[[thread_position_in_grid]]`` @@ -104,18 +106,20 @@ Let's convert ``myexp`` above to support arbitrarily strided arrays without rely kernel = mx.fast.metal_kernel( name="myexp_strided", + input_names=["inp"], + output_names=["out"], source=source ) outputs = kernel( - inputs={"inp": a}, - template={"T": mx.float32}, + inputs=[a], + template=[("T", mx.float32)], grid=(a.size, 1, 1), threadgroup=(256, 1, 1), - output_shapes={"out": a.shape}, - output_dtypes={"out": a.dtype}, + output_shapes=[a.shape], + output_dtypes=[a.dtype], ensure_row_contiguous=False, ) - return outputs["out"] + return outputs[0] a = mx.random.normal(shape=(4, 16)).astype(mx.float16) # make non-contiguous @@ -243,17 +247,19 @@ First we'll implement the forward pass as a fused kernel: """ kernel = mx.fast.metal_kernel( name="grid_sample", + input_names=["x", "grid"], + output_names=["out"], source=source, ) outputs = kernel( - inputs={"x": x, "grid": grid}, - template={"T": x.dtype}, - output_shapes={"out": out_shape}, - output_dtypes={"out": x.dtype}, + inputs=[x, grid], + template=[("T", x.dtype)], + output_shapes=[out_shape], + output_dtypes=[x.dtype], grid=(np.prod(out_shape), 1, 1), threadgroup=(256, 1, 1), ) - return outputs["out"] + return outputs[0] For a reasonably sized input such as: @@ -389,6 +395,8 @@ We can then implement the backwards pass as follows: """ kernel = mx.fast.metal_kernel( name="grid_sample_grad", + input_names=["x", "grid", "cotangent"], + output_names=["x_grad", "grid_grad"], source=source, atomic_outputs=True, ) @@ -398,15 +406,15 @@ We can then implement the backwards pass as follows: C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size grid_size = B * gN * gM * C_padded outputs = kernel( - inputs={"x": x, "grid": grid, "cotangent": cotangent}, - template={"T": x.dtype}, - output_shapes={"x_grad": x.shape, "grid_grad": grid.shape}, - output_dtypes={"x_grad": x.dtype, "grid_grad": x.dtype}, + inputs=[x, grid, cotangent], + template=[("T", x.dtype)], + output_shapes=[x.shape, grid.shape], + output_dtypes=[x.dtype, x.dtype], grid=(grid_size, 1, 1), threadgroup=(256, 1, 1), init_value=0, ) - return outputs["x_grad"], outputs["grid_grad"] + return outputs[0], outputs[1] There's an even larger speed up for the vjp: diff --git a/mlx/fast.cpp b/mlx/fast.cpp index c8c12af69..1a2afaaa0 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -515,7 +515,7 @@ array scaled_dot_product_attention( const array& values, const float scale, const std::optional& mask, - const std::optional& memory_efficient_threshold, + const std::optional memory_efficient_threshold, StreamOrDevice s) { for (const auto& tensor : {queries, keys, values}) { if (tensor.ndim() != 4) { @@ -916,47 +916,23 @@ array affine_dequantize( return fallback({w, scales, biases})[0]; } -void validate_output_shapes( - std::map> output_shapes, - std::map output_dtypes) { - // Make sure output shapes and dtypes have the same keys - bool validated = true; - if (output_shapes.size() == 0) { - throw std::invalid_argument( - "[metal_kernel] Must specify at least one output."); - } - if (output_shapes.size() != output_dtypes.size()) { - validated = false; - } else { - for (const auto& kv : output_shapes) { - if (output_dtypes.find(kv.first) == output_dtypes.end()) { - validated = false; - break; - } - } - } - if (!validated) { - throw std::invalid_argument( - "[metal_kernel] `output_shapes` and `output_dtypes` must have the same keys."); - } -} - void write_signature( std::string func_name, - 
std::string& source, - std::map& inputs, - std::map>& output_shapes, - std::map& output_dtypes, - std::optional> template_args, + const std::string& source, + const std::vector& input_names, + const std::vector& inputs, + const std::vector& output_names, + const std::vector& output_dtypes, + const std::vector>& template_args, std::vector& shape_infos, bool atomic_outputs, std::ostringstream& kernel_source) { // Auto-generate a function signature based on `template_args` // and the dtype/shape of the arrays passed as `inputs`. - if (template_args && template_args.value().size() > 0) { + if (!template_args.empty()) { kernel_source << "template <"; int i = 0; - for (const auto& [name, arg] : template_args.value()) { + for (const auto& [name, arg] : template_args) { std::string param_type; if (std::holds_alternative(arg)) { param_type = "int"; @@ -1008,7 +984,9 @@ void write_signature( int index = 0; constexpr int max_constant_array_size = 8; // Add inputs - for (const auto& [name, arr] : inputs) { + for (int i = 0; i < inputs.size(); ++i) { + const auto& name = input_names[i]; + const auto& arr = inputs[i]; auto dtype = get_type_string(arr.dtype()); bool is_constant = arr.is_available() && arr.size() < max_constant_array_size; @@ -1042,7 +1020,9 @@ void write_signature( shape_infos.push_back(shape_info); } // Add outputs - for (const auto& [name, dtype] : output_dtypes) { + for (int i = 0; i < output_names.size(); ++i) { + const auto& name = output_names[i]; + const auto& dtype = output_dtypes[i]; kernel_source << " device "; auto type_string = get_type_string(dtype); if (atomic_outputs) { @@ -1051,7 +1031,7 @@ void write_signature( kernel_source << type_string; } kernel_source << "* " << name << " [[buffer(" << index << ")]]"; - if (index < inputs.size() + output_shapes.size() - 1 || attrs.size() > 0) { + if (index < inputs.size() + output_names.size() - 1 || attrs.size() > 0) { kernel_source << "," << std::endl; } else { kernel_source << ") {" << std::endl; @@ -1073,7 +1053,8 @@ void write_signature( kernel_source << "}" << std::endl; } -std::string write_template(std::map& template_args) { +std::string write_template( + const std::vector>& template_args) { std::ostringstream template_def; template_def << "<"; int i = 0; @@ -1094,107 +1075,115 @@ std::string write_template(std::map& template_args) { return template_def.str(); } -std::map MetalKernel::operator()( - std::map& inputs, - std::map> output_shapes, - std::map output_dtypes, - std::tuple grid, - std::tuple threadgroup, - std::optional> template_args, - std::optional init_value, - bool verbose, - StreamOrDevice s_) { - validate_output_shapes(output_shapes, output_dtypes); - - auto s = to_stream(s_); - if (s.device != Device::gpu) { +MetalKernelFunction metal_kernel( + const std::string& name, + const std::vector& input_names, + const std::vector& output_names, + const std::string& source, + const std::string& header /* = "" */, + bool ensure_row_contiguous /* = true */, + bool atomic_outputs /* = false */) { + if (output_names.empty()) { throw std::invalid_argument( - "[metal_kernel] MetalKernel only works on GPU."); + "[metal_kernel] Must specify at least one output."); } - std::ostringstream func_name; + return [=](const std::vector& inputs, + const std::vector>& output_shapes, + const std::vector& output_dtypes, + std::tuple grid, + std::tuple threadgroup, + const std::vector>& + template_args = {}, + std::optional init_value = std::nullopt, + bool verbose = false, + StreamOrDevice s_ = {}) { + if (inputs.size() != 
input_names.size()) { + std::ostringstream msg; + msg << "[metal_kernel] Expected `inputs` to have size " + << input_names.size() << " but got size " << inputs.size() << "." + << std::endl; + throw std::invalid_argument(msg.str()); + } + if (output_shapes.size() != output_names.size()) { + std::ostringstream msg; + msg << "[metal_kernel] Expected `output_shapes` to have size " + << output_names.size() << " but got size " << output_shapes.size() + << "." << std::endl; + throw std::invalid_argument(msg.str()); + } + if (output_dtypes.size() != output_names.size()) { + std::ostringstream msg; + msg << "[metal_kernel] Expected `output_dtypes` to have size " + << output_names.size() << " but got size " << output_dtypes.size() + << "." << std::endl; + throw std::invalid_argument(msg.str()); + } - std::string template_def = ""; - bool needs_template = template_args && template_args.value().size() > 0; - std::string hash_key = ""; - if (needs_template) { - std::regex disallowed_chars("\\<|\\>|(, )"); - template_def = write_template(template_args.value()); - hash_key = std::regex_replace(template_def, disallowed_chars, "_"); - hash_key.pop_back(); - } + auto s = to_stream(s_); + if (s.device != Device::gpu) { + throw std::invalid_argument("[metal_kernel] Only supports the GPU."); + } - func_name << "custom_kernel_" << name_ << hash_key; - std::string kernel_name = func_name.str(); + std::ostringstream func_name; - std::ostringstream kernel_source; - kernel_source << header_ << std::endl; + std::string template_def = ""; + std::string hash_key = ""; + if (!template_args.empty()) { + std::regex disallowed_chars("\\<|\\>|(, )"); + template_def = write_template(template_args); + hash_key = std::regex_replace(template_def, disallowed_chars, "_"); + hash_key.pop_back(); + } - std::vector shape_infos; - write_signature( - func_name.str(), - source_, - inputs, - output_shapes, - output_dtypes, - template_args, - shape_infos, - atomic_outputs_, - kernel_source); + func_name << "custom_kernel_" << name << hash_key; + std::string kernel_name = func_name.str(); - if (needs_template) { - template_def = func_name.str() + template_def; - kernel_source << std::endl - << "template [[host_name(\"" << kernel_name - << "\")]] [[kernel]] decltype(" << template_def << ") " - << template_def << ";" << std::endl; - } + std::ostringstream kernel_source; + kernel_source << header << std::endl; - if (verbose) { - std::cout << "Generated source code for `" << name_ << "`:" << std::endl - << "```" << std::endl - << kernel_source.str() << std::endl - << "```" << std::endl; - } + std::vector shape_infos; + write_signature( + func_name.str(), + source, + input_names, + inputs, + output_names, + output_dtypes, + template_args, + shape_infos, + atomic_outputs, + kernel_source); - std::vector in_arrs; - for (const auto& kv : inputs) { - in_arrs.push_back(kv.second); - } + if (!template_args.empty()) { + template_def = func_name.str() + template_def; + kernel_source << std::endl + << "template [[host_name(\"" << kernel_name + << "\")]] [[kernel]] decltype(" << template_def << ") " + << template_def << ";" << std::endl; + } - std::vector out_keys; - std::vector> out_shapes; - for (const auto& [name, shape] : output_shapes) { - out_keys.push_back(name); - out_shapes.push_back(shape); - } + if (verbose) { + std::cout << "Generated source code for `" << name << "`:" << std::endl + << "```" << std::endl + << kernel_source.str() << std::endl + << "```" << std::endl; + } - std::vector out_dtypes; - for (const auto& kv : output_dtypes) { - 
out_dtypes.push_back(kv.second); - } - - std::map outputs; - auto outputs_vec = array::make_arrays( - out_shapes, - out_dtypes, - std::make_shared( - s, - kernel_name, - kernel_source.str(), - grid, - threadgroup, - shape_infos, - ensure_row_contiguous_, - init_value), - in_arrs); - - int i = 0; - for (const auto& key : out_keys) { - outputs.insert({key, outputs_vec[i]}); - i++; - } - return outputs; + return array::make_arrays( + output_shapes, + output_dtypes, + std::make_shared( + s, + kernel_name, + kernel_source.str(), + grid, + threadgroup, + shape_infos, + ensure_row_contiguous, + init_value), + inputs); + }; } } // namespace mlx::core::fast diff --git a/mlx/fast.h b/mlx/fast.h index 874aa529a..e1a876882 100644 --- a/mlx/fast.h +++ b/mlx/fast.h @@ -2,7 +2,6 @@ #pragma once -#include #include #include "mlx/utils.h" @@ -39,7 +38,7 @@ array scaled_dot_product_attention( const array& values, const float scale, const std::optional& mask = std::nullopt, - const std::optional& memory_efficient_threshold = std::nullopt, + const std::optional memory_efficient_threshold = std::nullopt, StreamOrDevice s = {}); std::tuple affine_quantize( @@ -66,37 +65,25 @@ array affine_dequantize( typedef std::variant TemplateArg; -class MetalKernel { - public: - MetalKernel( - const std::string& name, - const std::string& source, - const std::string& header = "", - bool ensure_row_contiguous = true, - bool atomic_outputs = false) - : name_(name), - source_(source), - header_(header), - ensure_row_contiguous_(ensure_row_contiguous), - atomic_outputs_(atomic_outputs) {} +typedef std::function( + const std::vector&, + const std::vector>&, + const std::vector&, + std::tuple, + std::tuple, + std::vector>, + std::optional, + bool, + StreamOrDevice)> + MetalKernelFunction; - std::map operator()( - std::map& inputs, - std::map> output_shapes, - std::map output_dtypes, - std::tuple grid, - std::tuple threadgroup, - std::optional> template_args = - std::nullopt, - std::optional init_value = std::nullopt, - bool verbose = false, - StreamOrDevice s = {}); +MetalKernelFunction metal_kernel( + const std::string& name, + const std::vector& input_names, + const std::vector& output_names, + const std::string& source, + const std::string& header = "", + bool ensure_row_contiguous = true, + bool atomic_outputs = false); - private: - std::string name_; - std::string source_; - std::string header_; - bool ensure_row_contiguous_; - bool atomic_outputs_; -}; } // namespace mlx::core::fast diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 56f8fa234..9b9a0dcea 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -1425,8 +1425,8 @@ array where( array nan_to_num( const array& a, float nan /* = 0.0f */, - const std::optional& posinf_ /* = std::nullopt */, - const std::optional& neginf_ /* = std::nullopt */, + const std::optional posinf_ /* = std::nullopt */, + const std::optional neginf_ /* = std::nullopt */, StreamOrDevice s /* = {} */) { Dtype dtype = a.dtype(); if (!issubdtype(dtype, inexact)) { diff --git a/mlx/ops.h b/mlx/ops.h index daff9bcdc..711a37aa9 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -416,8 +416,8 @@ array where( array nan_to_num( const array& a, float nan = 0.0f, - const std::optional& posinf = std::nullopt, - const std::optional& neginf = std::nullopt, + const std::optional posinf = std::nullopt, + const std::optional neginf = std::nullopt, StreamOrDevice s = {}); /** True if all elements in the array are true (or non-zero). 
**/ diff --git a/python/src/fast.cpp b/python/src/fast.cpp index b7b891933..b07e965e9 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -1,8 +1,8 @@ // Copyright © 2023-2024 Apple Inc. #include -#include #include +#include #include #include #include @@ -193,39 +193,130 @@ void init_fast(nb::module_& parent_module) { array: The quantized version of ``w`` )pbdoc"); - nb::class_( - m, + m.def( "metal_kernel", + [](const std::string& name, + const std::vector& input_names, + const std::vector& output_names, + const std::string& source, + const std::string& header, + bool ensure_row_contiguous, + bool atomic_outputs) { + auto kernel = fast::metal_kernel( + name, + input_names, + output_names, + source, + header, + ensure_row_contiguous, + atomic_outputs); + return nb::cpp_function( + [kernel = std::move(kernel)]( + const std::vector& inputs_, + const std::vector>& output_shapes, + const std::vector& output_dtypes, + std::tuple grid, + std::tuple threadgroup, + const std::optional< + std::vector>>& + template_args_ = std::nullopt, + std::optional init_value = std::nullopt, + bool verbose = false, + StreamOrDevice s = {}) { + std::vector inputs; + for (const auto& value : inputs_) { + inputs.push_back(to_array(value, std::nullopt)); + } + std::vector> + template_args; + if (template_args_) { + for (const auto& [name, value] : template_args_.value()) { + // Handle bool, int and dtype template args + if (nb::isinstance(value)) { + bool bool_val = nb::cast(value); + template_args.emplace_back(name, bool_val); + } else if (nb::isinstance(value)) { + int int_val = nb::cast(value); + template_args.emplace_back(name, int_val); + } else if (nb::isinstance(value)) { + Dtype dtype = nb::cast(value); + template_args.emplace_back(name, dtype); + } else { + throw std::invalid_argument( + "[[metal_kernel]] Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`."); + } + } + } + return kernel( + inputs, + output_shapes, + output_dtypes, + grid, + threadgroup, + template_args, + init_value, + verbose, + s); + }, + nb::kw_only(), + "inputs"_a, + "output_shapes"_a, + "output_dtypes"_a, + "grid"_a, + "threadgroup"_a, + "template"_a = nb::none(), + "init_value"_a = nb::none(), + "verbose"_a = false, + "stream"_a = nb::none(), + nb::sig( + "def __call__(self, *, inputs: List[Union[scalar, array]], output_shapes: List[Sequence[int]], output_dtypes: List[Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[List[Tuple[str, Union[bool, int, Dtype]]]] = None, init_value: Optional[float] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)"), + R"pbdoc( + Run the kernel. + + Args: + inputs (List[array]): The inputs passed to the Metal kernel. + output_shapes (List[Sequence[int]]): The list of shapes for each output in ``output_names``. + output_dtypes (List[Dtype]): The list of data types for each output in ``output_names``. + grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with. + threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use. + template (List[Tuple[str, Union[bool, int, Dtype]]], optional): Template arguments. + These will be added as template arguments to the kernel definition. Default: ``None``. + init_value (float, optional): Optional value to use to initialize all of the output arrays. + By default, output arrays are uninitialized. Default: ``None``. + verbose (bool, optional): Whether to print the full generated source code of the kernel + when it is run. 
Default: ``False``.
+              stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.
+
+            Returns:
+              List[array]: The list of output arrays.
+        )pbdoc");
+      },
+      "name"_a,
+      "input_names"_a,
+      "output_names"_a,
+      "source"_a,
+      "header"_a = "",
+      "ensure_row_contiguous"_a = true,
+      "atomic_outputs"_a = false,
       R"pbdoc(
         A jit-compiled custom Metal kernel defined from a source string.
-      )pbdoc")
-      .def(
-          nb::init<
-              const std::string&,
-              const std::string&,
-              const std::string&,
-              bool,
-              bool>(),
-          "name"_a,
-          "source"_a,
-          "header"_a = "",
-          "ensure_row_contiguous"_a = true,
-          "atomic_outputs"_a = false,
-          R"pbdoc(
-          Initialize a metal_kernel.
 
         Args:
-          name (str): Name for the kernel.
-          source (str): Source code. This is the body of a function in Metal,
-            the function signature will be generated for you. The names of the inputs/outputs
-            are determined by the ``inputs`` and ``output_shapes``/``output_dtypes``
-            used when the kernel is called.
-          header (str): Header source code to include before the main function.
-            Useful for helper functions or includes that should live outside of the main function body.
-          ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
-            before the kernel runs. Default: ``True``.
-          atomic_outputs (bool): Whether to use atomic outputs in the function signature
-            e.g. ``device atomic<float>``. Default: ``False``.
+          name (str): Name for the kernel.
+          input_names (List[str]): The parameter names of the inputs in the
+            function signature.
+          output_names (List[str]): The parameter names of the outputs in the
+            function signature.
+          source (str): Source code. This is the body of a function in Metal,
+            the function signature will be automatically generated.
+          header (str): Header source code to include before the main function.
+            Useful for helper functions or includes that should live outside of
+            the main function body.
+          ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
+            before the kernel runs. Default: ``True``.
+          atomic_outputs (bool): Whether to use atomic outputs in the function signature
+            e.g. ``device atomic<float>``. Default: ``False``.
+
         Returns:
           Callable ``metal_kernel``.
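As an aside on the ``atomic_outputs`` and ``init_value`` options documented above, here is a hedged sketch of a single-output accumulation kernel. The name ``atomic_sum`` and the reduction pattern are illustrative only (not part of this diff), and it assumes a device where Metal float atomics are available:

```python
import mlx.core as mx

# With atomic_outputs=True the generated signature declares the output as
# `device atomic<float>*`, so the body can accumulate with atomic fetch-adds.
source = """
    uint elem = thread_position_in_grid.x;
    atomic_fetch_add_explicit(&out[0], inp[elem], memory_order_relaxed);
"""

kernel = mx.fast.metal_kernel(
    name="atomic_sum",  # hypothetical kernel name
    input_names=["inp"],
    output_names=["out"],
    source=source,
    atomic_outputs=True,
)

x = mx.ones((1024,))  # float32 ones
total = kernel(
    inputs=[x],
    grid=(x.size, 1, 1),
    threadgroup=(256, 1, 1),
    output_shapes=[(1,)],
    output_dtypes=[mx.float32],
    init_value=0,  # zero-initialize the output before threads accumulate into it
)[0]
# total should be mx.array([1024.0]), i.e. the sum of the input
```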
@@ -242,103 +333,23 @@ void init_fast(nb::module_& parent_module) { kernel = mx.fast.metal_kernel( name="myexp", + input_names=["inp"], + output_names=["out"], source=source ) outputs = kernel( - inputs={"inp": a}, - template={"T": mx.float32}, + inputs=[a], + template=[("T", mx.float32)], grid=(a.size, 1, 1), threadgroup=(256, 1, 1), - output_shapes={"out": a.shape}, - output_dtypes={"out": a.dtype}, + output_shapes=[a.shape], + output_dtypes=[a.dtype], verbose=True, ) - return outputs["out"] + return outputs[0] a = mx.random.normal(shape=(4, 16)).astype(mx.float16) b = exp_elementwise(a) assert mx.allclose(b, mx.exp(a)) - )pbdoc") - .def( - "__call__", - [](fast::MetalKernel& kernel, - std::map& inputs_, - std::map>& output_shapes, - std::map& output_dtypes, - std::tuple grid, - std::tuple threadgroup, - std::optional> template_args_, - std::optional init_value, - bool verbose, - StreamOrDevice s) { - std::map inputs; - for (const auto& [name, value] : inputs_) { - auto arr = to_array(value, std::nullopt); - inputs.insert({name, arr}); - } - std::map template_args; - if (template_args_) { - for (const auto& [name, value] : template_args_.value()) { - // Handle bool, int and dtype template args - if (nb::isinstance(value)) { - bool bool_val = nb::cast(value); - template_args.insert({name, bool_val}); - } else if (nb::isinstance(value)) { - int int_val = nb::cast(value); - template_args.insert({name, int_val}); - } else if (nb::isinstance(value)) { - Dtype dtype = nb::cast(value); - template_args.insert({name, dtype}); - } else { - throw std::invalid_argument( - "[[metal_kernel]] Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`."); - } - } - } - return kernel( - inputs, - output_shapes, - output_dtypes, - grid, - threadgroup, - template_args, - init_value, - verbose, - s); - }, - nb::kw_only(), - "inputs"_a, - "output_shapes"_a, - "output_dtypes"_a, - "grid"_a, - "threadgroup"_a, - "template"_a = nb::none(), - "init_value"_a = nb::none(), - "verbose"_a = false, - "stream"_a = nb::none(), - nb::sig( - "def __call__(self, *, inputs: Mapping[str, Union[scalar, array]], output_shapes: Mapping[str, Sequence[int]], output_dtypes: Mapping[str, Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[Mapping[str, Union[bool, int, Dtype]]] = None, init_value: Optional[float] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)"), - R"pbdoc( - Run the kernel. - - Args: - inputs (Mapping[str, array]): Inputs. These will be added to the function signature and passed to the Metal kernel. - The keys will be the names of the arguments to the kernel. - output_shapes (Mapping[str, Sequence[int]]): Output shapes. A dict mapping - output variable names to shapes. These will be added to the function signature. - output_dtypes (Mapping[str, Dtype]): Output dtypes. A dict mapping output variable - names to dtypes. Must have the same keys as ``output_shapes``. - grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with. - threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use. - template (Mapping[str, Union[bool, int, Dtype]], optional): Template arguments. - These will be added as template arguments to the kernel definition. Default: ``None``. - init_value (float, optional): Optional value to use to initialize all of the output arrays. - By default, output arrays are uninitialized. Default: ``None``. 
- verbose (bool, optional): Whether to print the full generated source code of the kernel - when it is run. Default: ``False``. - stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``. - - Returns: - dict[str, array]: Dictionary of output arrays based on ``output_shapes``/``output_dtypes``. - )pbdoc"); + )pbdoc"); } diff --git a/python/tests/test_fast.py b/python/tests/test_fast.py index c881eced6..db107eec1 100644 --- a/python/tests/test_fast.py +++ b/python/tests/test_fast.py @@ -562,20 +562,22 @@ class TestFast(mlx_tests.MLXTestCase): a = mx.random.normal(shape=(2, 2)) kernel = mx.fast.metal_kernel( name="basic", + input_names=["a"], + output_names=["out1"], source=""" uint elem = thread_position_in_grid.x; out1[elem] = a[elem]; """, ) out = kernel( - inputs={"a": a}, + inputs=[a], grid=(4, 1, 1), threadgroup=(2, 1, 1), - output_shapes={"out1": (2, 2)}, - output_dtypes={"out1": mx.float32}, + output_shapes=[(2, 2)], + output_dtypes=[mx.float32], stream=mx.gpu, ) - self.assertTrue(mx.allclose(out["out1"], a)) + self.assertTrue(mx.allclose(out[0], a)) @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") def test_custom_kernel_args(self): @@ -585,6 +587,8 @@ class TestFast(mlx_tests.MLXTestCase): kernel = mx.fast.metal_kernel( name="arg_test", + input_names=["a", "b", "c", "d"], + output_names=["out1", "out2"], source=""" uint elem = thread_position_in_grid.x; T tmp = a[0]; @@ -597,26 +601,26 @@ class TestFast(mlx_tests.MLXTestCase): """, ) out = kernel( - inputs={ - "a": a, - "b": mx.array([3, 4, 5]), - "c": c, - "d": 7.3, - }, - template={ - "e": True, - "f": 3, - "T": mx.float16, - }, + inputs=[ + a, + mx.array([3, 4, 5]), + c, + 7.3, + ], + template=[ + ("e", True), + ("f", 3), + ("T", mx.float16), + ], grid=(6, 1, 1), threadgroup=(2, 1, 1), - output_shapes={"out1": (2, 2), "out2": (3, 2)}, - output_dtypes={"out1": mx.float32, "out2": mx.int32}, + output_shapes=[(2, 2), (3, 2)], + output_dtypes=[mx.float32, mx.int32], stream=mx.gpu, ) - self.assertTrue(mx.allclose(out["out1"], mx.full((2, 2), 14.0484))) - self.assertTrue(mx.allclose(out["out2"], mx.full((3, 2), -2, dtype=mx.int32))) + self.assertTrue(mx.allclose(out[0], mx.full((2, 2), 14.0484))) + self.assertTrue(mx.allclose(out[1], mx.full((3, 2), -2, dtype=mx.int32))) @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") def test_custom_kernel_strides(self): @@ -640,19 +644,21 @@ class TestFast(mlx_tests.MLXTestCase): for contig in [True, False]: kernel = mx.fast.metal_kernel( name="myexp" + str(contig), + input_names=["inp"], + output_names=["out"], source=source_contig if contig else source, ensure_row_contiguous=contig, ) outputs = kernel( - inputs={"inp": a}, - template={"T": mx.float32}, + inputs=[a], + template=[("T", mx.float32)], grid=(a.size, 1, 1), threadgroup=(256, 1, 1), - output_shapes={"out": a.shape}, - output_dtypes={"out": a.dtype}, + output_shapes=[a.shape], + output_dtypes=[a.dtype], stream=mx.gpu, ) - self.assertTrue(mx.allclose(mx.exp(a) * 32, outputs["out"])) + self.assertTrue(mx.allclose(mx.exp(a) * 32, outputs[0])) @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") def test_custom_kernel_helper(self): @@ -660,6 +666,8 @@ class TestFast(mlx_tests.MLXTestCase): a = mx.random.normal(shape=(2, 2)) kernel = mx.fast.metal_kernel( name="helper", + input_names=["a"], + output_names=["out1"], header=""" template T do_exp(T x) { @@ -672,14 +680,14 @@ class TestFast(mlx_tests.MLXTestCase): """, ) out = kernel( - inputs={"a": a}, + 
inputs=[a],
             grid=(4, 1, 1),
             threadgroup=(2, 1, 1),
-            output_shapes={"out1": (2, 2)},
-            output_dtypes={"out1": mx.float32},
+            output_shapes=[(2, 2)],
+            output_dtypes=[mx.float32],
             stream=mx.gpu,
         )
-        self.assertTrue(mx.allclose(out["out1"], mx.exp(a)))
+        self.assertTrue(mx.allclose(out[0], mx.exp(a)))
 
 
 if __name__ == "__main__":
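For readers skimming the patch, here is a before/after sketch of the calling convention this diff migrates, using the same ``myexp`` example as the documentation above (the "old" form is the one being removed, the "new" form the one being added):

```python
import mlx.core as mx

a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
source = """
    uint elem = thread_position_in_grid.x;
    T tmp = inp[elem];
    out[elem] = metal::exp(tmp);
"""

# Old, dict-based API (removed by this diff):
#   kernel = mx.fast.metal_kernel(name="myexp", source=source)
#   out = kernel(
#       inputs={"inp": a},
#       template={"T": mx.float32},
#       grid=(a.size, 1, 1),
#       threadgroup=(256, 1, 1),
#       output_shapes={"out": a.shape},
#       output_dtypes={"out": a.dtype},
#   )["out"]

# New, list-based API (added by this diff): names are given once to the
# factory via input_names/output_names, and the call site passes lists.
kernel = mx.fast.metal_kernel(
    name="myexp",
    input_names=["inp"],
    output_names=["out"],
    source=source,
)
out = kernel(
    inputs=[a],
    template=[("T", mx.float32)],
    grid=(a.size, 1, 1),
    threadgroup=(256, 1, 1),
    output_shapes=[a.shape],
    output_dtypes=[a.dtype],
)[0]
assert mx.allclose(out, mx.exp(a))
```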