diff --git a/docs/src/dev/custom_metal_kernels.rst b/docs/src/dev/custom_metal_kernels.rst new file mode 100644 index 000000000..12204d24a --- /dev/null +++ b/docs/src/dev/custom_metal_kernels.rst @@ -0,0 +1,123 @@ +Custom Metal Kernels +==================== + +MLX supports writing custom Metal kernels through the Python and C++ APIs. + +Simple Example +-------------- + +Let's write a custom kernel that computes ``exp`` elementwise: + +.. code-block:: python + + def exp_elementwise(a: mx.array): + source = """ + uint elem = thread_position_in_grid.x; + T tmp = inp[elem]; + out[elem] = metal::exp(tmp); + """ + + kernel = mx.fast.metal_kernel( + name="myexp", + source=source, + ) + outputs = kernel( + inputs={"inp": a}, + template={"T": mx.float32}, + grid=(a.size, 1, 1), + threadgroup=(256, 1, 1), + output_shapes={"out": a.shape}, + output_dtypes={"out": a.dtype}, + ) + return outputs["out"] + + a = mx.random.normal(shape=(4, 16)).astype(mx.float16) + b = exp_elementwise(a) + assert mx.allclose(b, mx.exp(a)) + +.. note:: + We are only required to pass the body of the Metal kernel in ``source``. + +The full function signature will be generated using: + +* The keys and shapes/dtypes of ``inputs`` + In the above, ``a`` is an ``mx.array`` of type ``mx.float16`` and we pass it with the key ``inp`` + so we will add ``const device float16_t* inp`` to the signature. + ``inp_shape``, ``inp_strides`` and ``inp_ndim`` are also added for convenience. +* The keys and values of ``output_shapes`` and ``output_dtypes`` + In the above, ``out`` is an ``mx.array`` of type ``mx.float16`` + so we add ``device float16_t* out``. +* Template parameters passed using ``template`` + In the above, ``template={"T": mx.float32}`` adds a template of ``template `` to the function + and instantiates the template with ``custom_kernel_myexp_float``. + Template parameters can be ``mx.core.Dtype``, ``int`` or ``bool``. +* Metal attributes used in ``source`` such as ``[[thread_position_in_grid]]`` + These will be added as function arguments. + All the attributes defined in Table 5.8 of the `Metal Shading Language Specification `_ are supported. + +Putting this all together, the generated function signature for ``myexp`` is as follows: + +.. code-block:: cpp + + template + [[kernel]] void custom_kernel_myexp_float( + const device float16_t* inp [[buffer(0)]], + device float16_t* out [[buffer(1)]], + uint3 thread_position_in_grid [[thread_position_in_grid]]) { + + uint elem = thread_position_in_grid.x; + T tmp = inp[elem]; + out[elem] = metal::exp(tmp); + + } + + template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float) custom_kernel_myexp_float; + +You can print the generated code for a ``mx.fast.metal_kernel`` by passing ``verbose=True`` when you call it. + +Using Shape/Strides +------------------- + +``mx.fast.metal_kernel`` supports an argument ``ensure_row_contiguous`` which is ``True`` by default. +This will copy the ``mx.array`` inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous. +Generally this makes writing the kernel easier, since we don't have to worry about gaps or the ordering of the dims +when indexing. + +If we want to avoid this copy, ``metal_kernel`` automatically passes ``a_shape``, ``a_strides`` and ``a_ndim`` for each +input array ``a`` if any are present in ``source``. +We can then use MLX's built in indexing utils to fetch the right elements for each thread. + +Let's convert ``myexp`` above to support arbitrarily strided arrays without relying on a copy from ``ensure_row_contiguous``: + +.. code-block:: python + + def exp_elementwise(a: mx.array): + source = """ + uint elem = thread_position_in_grid.x; + // Utils from `mlx/backend/metal/kernels/utils.h` are automatically included + uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim); + T tmp = inp[loc]; + // Output arrays are always row contiguous + out[elem] = metal::exp(tmp); + """ + + kernel = mx.fast.metal_kernel( + name="myexp_strided", + source=source + ) + outputs = kernel( + inputs={"inp": a}, + template={"T": mx.float32}, + grid=(a.size, 1, 1), + threadgroup=(256, 1, 1), + output_shapes={"out": a.shape}, + output_dtypes={"out": a.dtype}, + ensure_row_contiguous=False, + ) + return outputs["out"] + + a = mx.random.normal(shape=(4, 16)).astype(mx.float16) + # make non-contiguous + a = a[::2] + b = exp_elementwise(a) + assert mx.allclose(b, mx.exp(a)) diff --git a/docs/src/index.rst b/docs/src/index.rst index fd5147ca6..1e5e6ad8a 100644 --- a/docs/src/index.rst +++ b/docs/src/index.rst @@ -85,3 +85,4 @@ are the CPU and GPU. dev/extensions dev/metal_debugger + dev/custom_metal_kernels diff --git a/docs/src/python/fast.rst b/docs/src/python/fast.rst index 26bd62a26..30ade264e 100644 --- a/docs/src/python/fast.rst +++ b/docs/src/python/fast.rst @@ -12,3 +12,5 @@ Fast layer_norm rope scaled_dot_product_attention + affine_quantize + metal_kernel diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index 38bb1589d..5c470e4cf 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -131,6 +131,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/event.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp diff --git a/mlx/backend/metal/custom_kernel.cpp b/mlx/backend/metal/custom_kernel.cpp new file mode 100644 index 000000000..b7c3f7ebf --- /dev/null +++ b/mlx/backend/metal/custom_kernel.cpp @@ -0,0 +1,84 @@ +// Copyright © 2024 Apple Inc. + +#include "mlx/backend/metal/copy.h" +#include "mlx/backend/metal/jit/includes.h" +#include "mlx/backend/metal/utils.h" +#include "mlx/fast_primitives.h" + +namespace mlx::core::fast { + +void CustomKernel::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + auto& s = stream(); + + for (auto& out : outputs) { + out.set_data(allocator::malloc_or_wait(out.nbytes())); + } + + std::vector copies; + + auto check_input = [&copies, &s, this](const array& x) -> const array { + bool no_copy = x.flags().row_contiguous; + if (!ensure_row_contiguous_ || no_copy) { + return x; + } else { + copies.push_back(array(x.shape(), x.dtype(), nullptr, {})); + copy_gpu(x, copies.back(), CopyType::General, s); + return copies.back(); + } + }; + std::vector checked_inputs; + for (const array& in : inputs) { + checked_inputs.push_back(check_input(in)); + } + + auto& d = metal::device(s.device); + const auto& lib_name = name_; + auto lib = d.get_library(lib_name); + if (lib == nullptr) { + lib = d.get_library(lib_name, metal::utils() + source_); + } + auto kernel = d.get_kernel(name_, lib); + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder->setComputePipelineState(kernel); + int index = 0; + for (int i = 0; i < checked_inputs.size(); i++) { + const array& in = checked_inputs[i]; + auto shape_info = shape_infos_[i]; + compute_encoder.set_input_array(in, index); + index++; + if (in.ndim() > 0) { + int ndim = in.ndim(); + if (shape_info.shape) { + set_vector_bytes(compute_encoder, in.shape(), ndim, index); + index++; + } + if (shape_info.strides) { + set_vector_bytes(compute_encoder, in.strides(), ndim, index); + index++; + } + if (shape_info.ndim) { + compute_encoder->setBytes(&ndim, sizeof(int), index); + index++; + } + } + } + for (array out : outputs) { + compute_encoder.set_output_array(out, index); + index++; + } + + const auto [tx, ty, tz] = threadgroup_; + MTL::Size group_dims = MTL::Size(tx, ty, tz); + const auto [gx, gy, gz] = grid_; + MTL::Size grid_dims = MTL::Size(gx, gy, gz); + compute_encoder->dispatchThreads(grid_dims, group_dims); + + if (!copies.empty()) { + d.get_command_buffer(s.index)->addCompletedHandler( + [copies](MTL::CommandBuffer*) mutable { copies.clear(); }); + } +} + +} // namespace mlx::core::fast diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp index 8ee6bb1e2..ac9e6ca68 100644 --- a/mlx/backend/no_metal/primitives.cpp +++ b/mlx/backend/no_metal/primitives.cpp @@ -119,6 +119,7 @@ NO_GPU_MULTI(RMSNormVJP) NO_GPU_MULTI(RoPE) NO_GPU(ScaledDotProductAttention) NO_GPU_MULTI(AffineQuantize) +NO_GPU_MULTI(CustomKernel) } // namespace fast } // namespace mlx::core diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 4a8cf791b..d06b4ce0e 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -1,7 +1,10 @@ // Copyright © 2023-2024 Apple Inc. #include +#include #include +#include +#include "mlx/backend/common/compiled.h" #include "mlx/fast.h" #include "mlx/fast_primitives.h" #include "mlx/ops.h" @@ -913,4 +916,271 @@ array affine_dequantize( return fallback({w, scales, biases})[0]; } +void validate_output_shapes( + std::map> output_shapes, + std::map output_dtypes) { + // Make sure output shapes and dtypes have the same keys + bool validated = true; + if (output_shapes.size() == 0) { + throw std::invalid_argument( + "[metal_kernel] Must specify at least one output."); + } + if (output_shapes.size() != output_dtypes.size()) { + validated = false; + } else { + for (const auto& kv : output_shapes) { + if (output_dtypes.find(kv.first) == output_dtypes.end()) { + validated = false; + break; + } + } + } + if (!validated) { + throw std::invalid_argument( + "[metal_kernel] `output_shapes` and `output_dtypes` must have the same keys."); + } +} + +void write_signature( + std::string func_name, + std::string& source, + std::map& inputs, + std::map>& output_shapes, + std::map& output_dtypes, + std::optional> template_args, + std::vector& shape_infos, + std::ostringstream& kernel_source) { + // Auto-generate a function signature based on `template_args` + // and the dtype/shape of the arrays passed as `inputs`. + if (template_args && template_args.value().size() > 0) { + kernel_source << "template <"; + int i = 0; + for (const auto& [name, arg] : template_args.value()) { + std::string param_type; + if (std::holds_alternative(arg)) { + param_type = "int"; + } else if (std::holds_alternative(arg)) { + param_type = "bool"; + } else if (std::holds_alternative(arg)) { + param_type = "typename"; + } + if (i > 0) { + kernel_source << ", "; + } + kernel_source << param_type << " " << name; + i++; + } + kernel_source << ">" << std::endl; + } + kernel_source << "[[kernel]] void " << func_name << "(" << std::endl; + + // Metal attributes are automatically added to the arguments if present + const std::vector> metal_attributes = { + {"dispatch_quadgroups_per_threadgroup", "uint"}, + {"dispatch_simdgroups_per_threadgroup", "uint"}, + {"dispatch_threads_per_threadgroup", "uint3"}, + {"grid_origin", "uint3"}, + {"grid_size", "uint3"}, + {"quadgroup_index_in_threadgroup", "uint"}, + {"quadgroups_per_threadgroup", "uint"}, + {"simdgroup_index_in_threadgroup", "uint"}, + {"simdgroups_per_threadgroup", "uint"}, + {"thread_execution_width", "uint"}, + {"thread_index_in_quadgroup", "uint"}, + {"thread_index_in_simdgroup", "uint"}, + {"thread_index_in_threadgroup", "uint"}, + {"thread_position_in_grid", "uint3"}, + {"thread_position_in_threadgroup", "uint3"}, + {"threadgroup_position_in_grid", "uint3"}, + {"threadgroups_per_grid", "uint3"}, + {"threads_per_grid", "uint3"}, + {"threads_per_simdgroup", "uint"}, + {"thread_per_threadgroup", "uint3"}, + }; + std::vector> attrs; + for (const auto& [attr, dtype] : metal_attributes) { + if (source.find(attr) != std::string::npos) { + attrs.push_back({attr, dtype}); + } + } + + int index = 0; + constexpr int max_constant_array_size = 8; + // Add inputs + for (const auto& [name, arr] : inputs) { + auto dtype = get_type_string(arr.dtype()); + bool is_constant = + arr.is_available() && arr.size() < max_constant_array_size; + std::string location = is_constant ? "constant" : "device"; + std::string ref = arr.ndim() == 0 ? "&" : "*"; + kernel_source << " const " << location << " " << dtype << ref << " " + << name << " [[buffer(" << index << ")]]," << std::endl; + index++; + // Add input shape, strides and ndim if present in the source + CustomKernelShapeInfo shape_info; + if (arr.ndim() > 0) { + if (source.find(name + "_shape") != std::string::npos) { + kernel_source << " const constant int* " << name << "_shape [[buffer(" + << index << ")]]," << std::endl; + shape_info.shape = true; + index++; + } + if (source.find(name + "_strides") != std::string::npos) { + kernel_source << " const constant size_t* " << name + << "_strides [[buffer(" << index << ")]]," << std::endl; + shape_info.strides = true; + index++; + } + if (source.find(name + "_ndim") != std::string::npos) { + kernel_source << " const constant int& " << name << "_ndim [[buffer(" + << index << ")]]," << std::endl; + shape_info.ndim = true; + index++; + } + } + shape_infos.push_back(shape_info); + } + // Add outputs + for (const auto& [name, dtype] : output_dtypes) { + kernel_source << " device " << get_type_string(dtype) << "* " << name + << " [[buffer(" << index << ")]]"; + if (index < inputs.size() + output_shapes.size() - 1 || attrs.size() > 0) { + kernel_source << "," << std::endl; + } else { + kernel_source << ") {" << std::endl; + } + index++; + } + // Add metal attributes e.g. `threadgroup_index_in_grid` + for (const auto& [attr, dtype] : attrs) { + kernel_source << " " << dtype << " " << attr << " [[" << attr << "]]"; + if (index < attrs.size() - 1) { + kernel_source << "," << std::endl; + } else { + kernel_source << ") {" << std::endl; + } + } + kernel_source << source << std::endl; + kernel_source << "}" << std::endl; +} + +std::string write_template(std::map& template_args) { + std::ostringstream template_def; + template_def << "<"; + int i = 0; + for (const auto& [name, arg] : template_args) { + if (i > 0) { + template_def << ", "; + } + if (std::holds_alternative(arg)) { + template_def << std::get(arg); + } else if (std::holds_alternative(arg)) { + template_def << std::get(arg); + } else if (std::holds_alternative(arg)) { + template_def << get_type_string(std::get(arg)); + } + i++; + } + template_def << ">"; + return template_def.str(); +} + +std::map MetalKernel::operator()( + std::map& inputs, + std::map> output_shapes, + std::map output_dtypes, + std::tuple grid, + std::tuple threadgroup, + std::optional> template_args, + bool verbose, + StreamOrDevice s_) { + validate_output_shapes(output_shapes, output_dtypes); + + auto s = to_stream(s_); + if (s.device != Device::gpu) { + throw std::invalid_argument( + "[metal_kernel] MetalKernel only works on GPU."); + } + + std::ostringstream kernel_source; + std::ostringstream func_name; + + std::string template_def = ""; + bool needs_template = template_args && template_args.value().size() > 0; + std::string hash_key = ""; + if (needs_template) { + std::regex disallowed_chars("\\<|\\>|(, )"); + template_def = write_template(template_args.value()); + hash_key = std::regex_replace(template_def, disallowed_chars, "_"); + hash_key.pop_back(); + } + + func_name << "custom_kernel_" << name_ << hash_key; + std::string kernel_name = func_name.str(); + + std::vector shape_infos; + write_signature( + func_name.str(), + source_, + inputs, + output_shapes, + output_dtypes, + template_args, + shape_infos, + kernel_source); + + if (needs_template) { + template_def = func_name.str() + template_def; + kernel_source << std::endl + << "template [[host_name(\"" << kernel_name + << "\")]] [[kernel]] decltype(" << template_def << ") " + << template_def << ";" << std::endl; + } + + if (verbose) { + std::cout << "Generated source code for `" << name_ << "`:" << std::endl + << "```" << std::endl + << kernel_source.str() << std::endl + << "```" << std::endl; + } + + std::vector in_arrs; + for (const auto& kv : inputs) { + in_arrs.push_back(kv.second); + } + + std::vector out_keys; + std::vector> out_shapes; + for (const auto& [name, shape] : output_shapes) { + out_keys.push_back(name); + out_shapes.push_back(shape); + } + + std::vector out_dtypes; + for (const auto& kv : output_dtypes) { + out_dtypes.push_back(kv.second); + } + + std::map outputs; + auto outputs_vec = array::make_arrays( + out_shapes, + out_dtypes, + std::make_shared( + s, + kernel_name, + kernel_source.str(), + grid, + threadgroup, + shape_infos, + ensure_row_contiguous_), + in_arrs); + + int i = 0; + for (const auto& key : out_keys) { + outputs.insert({key, outputs_vec[i]}); + i++; + } + return outputs; +} + } // namespace mlx::core::fast diff --git a/mlx/fast.h b/mlx/fast.h index 0274bf6dd..4b6e1e2c1 100644 --- a/mlx/fast.h +++ b/mlx/fast.h @@ -2,6 +2,7 @@ #pragma once +#include #include #include "mlx/utils.h" @@ -63,4 +64,32 @@ array affine_dequantize( int bits = 4, StreamOrDevice s = {}); +typedef std::variant TemplateArg; + +class MetalKernel { + public: + MetalKernel( + const std::string& name, + const std::string& source, + bool ensure_row_contiguous) + : name_(name), + source_(source), + ensure_row_contiguous_(ensure_row_contiguous) {} + + std::map operator()( + std::map& inputs, + std::map> output_shapes, + std::map output_dtypes, + std::tuple grid, + std::tuple threadgroup, + std::optional> template_args = + std::nullopt, + bool verbose = false, + StreamOrDevice s = {}); + + private: + std::string name_; + std::string source_; + bool ensure_row_contiguous_ = true; +}; } // namespace mlx::core::fast diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h index 1883f5789..0039ff01a 100644 --- a/mlx/fast_primitives.h +++ b/mlx/fast_primitives.h @@ -242,4 +242,47 @@ class AffineQuantize : public Custom { bool dequantize_; }; +struct CustomKernelShapeInfo { + bool shape = false; + bool strides = false; + bool ndim = false; +}; + +class CustomKernel : public Primitive { + public: + CustomKernel( + Stream stream, + std::string name, + std::string source, + std::tuple grid, + std::tuple threadgroup, + std::vector shape_infos, + bool ensure_row_contiguous) + : Primitive(stream), + source_(source), + name_(name), + grid_(grid), + threadgroup_(threadgroup), + shape_infos_(shape_infos), + ensure_row_contiguous_(ensure_row_contiguous) {} + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override { + throw std::runtime_error("Custom Metal kernels only run on GPU."); + } + + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + DEFINE_PRINT(CustomKernel); + + private: + std::string source_; + std::string name_; + std::tuple grid_; + std::tuple threadgroup_; + std::vector shape_infos_; + bool ensure_row_contiguous_; +}; + } // namespace mlx::core::fast diff --git a/python/src/fast.cpp b/python/src/fast.cpp index 4c25c7ac7..92e0232f6 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -1,9 +1,14 @@ // Copyright © 2023-2024 Apple Inc. #include +#include #include +#include #include #include +#include + +#include "python/src/utils.h" #include "mlx/fast.h" #include "mlx/ops.h" @@ -186,4 +191,136 @@ void init_fast(nb::module_& parent_module) { Returns: array: The quantized version of ``w`` )pbdoc"); + + nb::class_( + m, + "metal_kernel", + R"pbdoc( + A jit-compiled custom Metal kernel defined from a source string. + )pbdoc") + .def( + nb::init(), + "name"_a, + "source"_a, + "ensure_row_contiguous"_a = true, + R"pbdoc( + Initialize a metal_kernel. + + Args: + name (str): Name for the kernel. + source (str): Source code. This is the body of a function in Metal, + the function signature will be generated for you. The names of the inputs/outputs + are determined by the ``inputs`` and ``output_shapes``/``output_dtypes`` + used when the kernel is called. + ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous + before the kernel runs. Default: ``True``. + Returns: + Callable ``metal_kernel``. + + .. code-block:: python + + def exp_elementwise(a: mx.array): + source = """ + uint elem = thread_position_in_grid.x; + T tmp = inp[elem]; + out[elem] = metal::exp(tmp); + """ + + kernel = mx.fast.metal_kernel( + name="myexp", + source=source + ) + outputs = kernel( + inputs={"inp": a}, + template={"T": mx.float32}, + grid=(a.size, 1, 1), + threadgroup=(256, 1, 1), + output_shapes={"out": a.shape}, + output_dtypes={"out": a.dtype}, + verbose=True, + ) + return outputs["out"] + + a = mx.random.normal(shape=(4, 16)).astype(mx.float16) + b = exp_elementwise(a) + assert mx.allclose(b, mx.exp(a)) + + )pbdoc") + .def( + "__call__", + [](fast::MetalKernel& kernel, + std::map& inputs_, + std::map>& output_shapes, + std::map& output_dtypes, + std::tuple grid, + std::tuple threadgroup, + std::optional> template_args_, + bool verbose, + StreamOrDevice s) { + std::map inputs; + for (const auto& [name, value] : inputs_) { + auto arr = to_array(value, std::nullopt); + inputs.insert({name, arr}); + } + std::map template_args; + if (template_args_) { + for (const auto& [name, value] : template_args_.value()) { + // Handle bool, int and dtype template args + if (nb::isinstance(value)) { + bool bool_val = nb::cast(value); + template_args.insert({name, bool_val}); + } else if (nb::isinstance(value)) { + int int_val = nb::cast(value); + template_args.insert({name, int_val}); + } else if (nb::isinstance(value)) { + Dtype dtype = nb::cast(value); + template_args.insert({name, dtype}); + } else { + throw std::invalid_argument( + "[[metal_kernel]] Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`."); + } + } + } + return kernel( + inputs, + output_shapes, + output_dtypes, + grid, + threadgroup, + template_args, + verbose, + s); + }, + nb::kw_only(), + "inputs"_a, + "output_shapes"_a, + "output_dtypes"_a, + "grid"_a, + "threadgroup"_a, + "template"_a = nb::none(), + "verbose"_a = false, + "stream"_a = nb::none(), + nb::sig( + "def __call__(self, *, inputs: Mapping[str, Union[scalar, array]], output_shapes: Mapping[str, Sequence[int]], output_dtypes: Mapping[str, Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[Mapping[str, Union[bool, int, Dtype]]] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)"), + R"pbdoc( + Run the kernel. + + Args: + inputs (Mapping[str, array]): Inputs. These will be added to the function signature and passed to the Metal kernel. + The keys will be the names of the arguments to the kernel. + output_shapes (Mapping[str, Sequence[int]]): Output shapes. A dict mapping + output variable names to shapes. These will be added to the function signature. + output_dtypes (Mapping[str, Dtype]): Output dtypes. A dict mapping output variable + names to dtypes. Must have the same keys as ``output_shapes``. + grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with. + threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use. + template (Mapping[str, Union[bool, int, Dtype]], optional): Template arguments. + These will be added as template arguments to the kernel definition. + verbose (bool, optional): Whether to print the full generated source code of the kernel + when it is run. + stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``. + + Returns: + dict[str, array]: Dictionary of output arrays based on ``output_shapes``/``output_dtypes``. + )pbdoc"); } diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp index dd79e44d6..b17ba185e 100644 --- a/python/src/linalg.cpp +++ b/python/src/linalg.cpp @@ -325,9 +325,9 @@ void init_linalg(nb::module_& parent_module) { nb::sig( "def cholesky_inv(L: array, upper: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - Compute the inverse of a real symmetric positive semi-definite matrix using it's Cholesky decomposition L. + Compute the inverse of a real symmetric positive semi-definite matrix using it's Cholesky decomposition. - Let A be a real symmetric positive semi-definite matrix and L its Cholesky definition such that: + Let :math:`\mathbf{A}` be a real symmetric positive semi-definite matrix and :math:`\mathbf{L}` its Cholesky decomposition such that: .. math:: @@ -339,7 +339,7 @@ void init_linalg(nb::module_& parent_module) { This function supports arrays with at least 2 dimensions. When the input has more than two dimensions, the Cholesky inverse is computed for each matrix - in the last two dimensions of ``L``. + in the last two dimensions of :math:`\mathbf{L}`. If the input matrix is not a triangular matrix behaviour is undefined. @@ -351,6 +351,6 @@ void init_linalg(nb::module_& parent_module) { in which case the default stream of the default device is used. Returns: - array: :math:`A^{-1}` where :math:`\mathbf{A} = \mathbf{L}\mathbf{L}^T`. + array: :math:`\mathbf{A^{-1}}` where :math:`\mathbf{A} = \mathbf{L}\mathbf{L}^T`. )pbdoc"); } diff --git a/python/tests/test_fast.py b/python/tests/test_fast.py index b15b737b3..121d5d482 100644 --- a/python/tests/test_fast.py +++ b/python/tests/test_fast.py @@ -548,6 +548,104 @@ class TestFast(mlx_tests.MLXTestCase): ) self.assertTrue(mx.allclose(w, w_p)) + @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") + def test_custom_kernel_basic(self): + mx.random.seed(7) + a = mx.random.normal(shape=(3, 6)) + kernel = mx.fast.metal_kernel( + name="basic", + source=""" + uint elem = thread_position_in_grid.x; + out1[elem] = a[elem]; + """, + ) + out = kernel( + inputs={"a": a}, + grid=(4, 1, 1), + threadgroup=(2, 1, 1), + output_shapes={"out1": (2, 2)}, + output_dtypes={"out1": mx.float32}, + stream=mx.gpu, + ) + mx.allclose(out["out1"], a[:2, :2]) + + @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") + def test_custom_kernel_args(self): + mx.random.seed(7) + a = mx.random.normal(shape=(3, 6)) + c = mx.random.normal(shape=(2, 2)).astype(mx.bfloat16) + + kernel = mx.fast.metal_kernel( + name="arg_test", + source=""" + uint elem = thread_position_in_grid.x; + T tmp = a[0]; + if (e) { + out1[elem] = a[1] + b[2] + c[3] + d + f; + } else { + out1[elem] = 1; + } + out2[elem] = a[1] + b[2] + c[1] - d; + """, + ) + out = kernel( + inputs={ + "a": a, + "b": mx.array([3, 4, 5]), + "c": c, + "d": 7.3, + }, + template={ + "e": True, + "f": 3, + "T": mx.float16, + }, + grid=(6, 1, 1), + threadgroup=(2, 1, 1), + output_shapes={"out1": (2, 2), "out2": (3, 2)}, + output_dtypes={"out1": mx.float32, "out2": mx.int32}, + stream=mx.gpu, + ) + + self.assertTrue(mx.allclose(out["out1"], mx.full((2, 2), 14.0484))) + self.assertTrue(mx.allclose(out["out2"], mx.full((3, 2), -2, dtype=mx.int32))) + + @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") + def test_custom_kernel_strides(self): + mx.random.seed(7) + a = mx.random.normal(shape=(3, 6)) + source = """ + uint elem = thread_position_in_grid.x; + uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim); + T tmp = inp[loc]; + out[elem] = metal::exp(tmp); + """ + source_contig = """ + uint elem = thread_position_in_grid.x; + T tmp = inp[elem]; + out[elem] = metal::exp(tmp); + """ + + # non contiguous + a = mx.tile(a[::2], [4, 1]) + + for contig in [True, False]: + kernel = mx.fast.metal_kernel( + name="myexp" + str(contig), + source=source_contig if contig else source, + ensure_row_contiguous=contig, + ) + outputs = kernel( + inputs={"inp": a}, + template={"T": mx.float32}, + grid=(a.size, 1, 1), + threadgroup=(256, 1, 1), + output_shapes={"out": a.shape}, + output_dtypes={"out": a.dtype}, + stream=mx.gpu, + ) + self.assertTrue(mx.allclose(mx.exp(a), outputs["out"])) + if __name__ == "__main__": unittest.main()