Custom Metal Kernels from Python (#1325)

* start

* simple kernels working

* restructure

* inverse example working

* docs + fixes

* missing file

* fix imports

* address comments

* add docs + fix test

* Review comments + refactor to a single function

* update docs

* remove hashing

* fix contig bug in test

* back to a class

* trailing whitespace

* fix tests

* match c++ and python apis

* add link + make args kw_only
Alex Barron 2024-08-22 13:46:29 -07:00 committed by GitHub
parent df3233454d
commit 0fd2a1f4b0
12 changed files with 793 additions and 4 deletions


@@ -0,0 +1,123 @@
Custom Metal Kernels
====================

MLX supports writing custom Metal kernels through the Python and C++ APIs.

Simple Example
--------------

Let's write a custom kernel that computes ``exp`` elementwise:

.. code-block:: python

  def exp_elementwise(a: mx.array):
      source = """
          uint elem = thread_position_in_grid.x;
          T tmp = inp[elem];
          out[elem] = metal::exp(tmp);
      """

      kernel = mx.fast.metal_kernel(
          name="myexp",
          source=source,
      )
      outputs = kernel(
          inputs={"inp": a},
          template={"T": mx.float32},
          grid=(a.size, 1, 1),
          threadgroup=(256, 1, 1),
          output_shapes={"out": a.shape},
          output_dtypes={"out": a.dtype},
      )
      return outputs["out"]

  a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
  b = exp_elementwise(a)
  assert mx.allclose(b, mx.exp(a))
.. note::
   We are only required to pass the body of the Metal kernel in ``source``.

The full function signature will be generated using:

* The keys and shapes/dtypes of ``inputs``

  In the above, ``a`` is an ``mx.array`` of type ``mx.float16`` and we pass it with the key ``inp``
  so we will add ``const device float16_t* inp`` to the signature.
  ``inp_shape``, ``inp_strides`` and ``inp_ndim`` are also added for convenience.

* The keys and values of ``output_shapes`` and ``output_dtypes``

  In the above, ``out`` is an ``mx.array`` of type ``mx.float16``
  so we add ``device float16_t* out``.

* Template parameters passed using ``template``

  In the above, ``template={"T": mx.float32}`` adds a template of ``template <typename T>`` to the function
  and instantiates the template with ``custom_kernel_myexp_float<float>``.
  Template parameters can be ``mx.core.Dtype``, ``int`` or ``bool``.

* Metal attributes used in ``source`` such as ``[[thread_position_in_grid]]``

  These will be added as function arguments.
  All the attributes defined in Table 5.8 of the `Metal Shading Language Specification <https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf>`_ are supported.
Putting this all together, the generated function signature for ``myexp`` is as follows:

.. code-block:: cpp

  template <typename T>
  [[kernel]] void custom_kernel_myexp_float(
    const device float16_t* inp [[buffer(0)]],
    device float16_t* out [[buffer(1)]],
    uint3 thread_position_in_grid [[thread_position_in_grid]]) {

      uint elem = thread_position_in_grid.x;
      T tmp = inp[elem];
      out[elem] = metal::exp(tmp);

  }

  template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;

You can print the generated code for a ``mx.fast.metal_kernel`` by passing ``verbose=True`` when you call it.
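
For example, a minimal sketch reusing ``kernel`` and ``a`` from the example above; the only change is the added ``verbose=True`` keyword:

.. code-block:: python

  # Passing verbose=True makes MLX print the full generated Metal source
  # (signature plus body) when the kernel is called.
  outputs = kernel(
      inputs={"inp": a},
      template={"T": mx.float32},
      grid=(a.size, 1, 1),
      threadgroup=(256, 1, 1),
      output_shapes={"out": a.shape},
      output_dtypes={"out": a.dtype},
      verbose=True,
  )
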
Using Shape/Strides
-------------------

``mx.fast.metal_kernel`` supports an argument ``ensure_row_contiguous`` which is ``True`` by default.
This will copy the ``mx.array`` inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous.
Generally this makes writing the kernel easier, since we don't have to worry about gaps or the ordering of the dims
when indexing.

If we want to avoid this copy, ``metal_kernel`` automatically passes ``a_shape``, ``a_strides`` and ``a_ndim`` for each
input array ``a`` if any are present in ``source``.
We can then use MLX's built in indexing utils to fetch the right elements for each thread.

Let's convert ``myexp`` above to support arbitrarily strided arrays without relying on a copy from ``ensure_row_contiguous``:
.. code-block:: python

  def exp_elementwise(a: mx.array):
      source = """
          uint elem = thread_position_in_grid.x;
          // Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
          uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
          T tmp = inp[loc];
          // Output arrays are always row contiguous
          out[elem] = metal::exp(tmp);
      """

      kernel = mx.fast.metal_kernel(
          name="myexp_strided",
          source=source
      )
      outputs = kernel(
          inputs={"inp": a},
          template={"T": mx.float32},
          grid=(a.size, 1, 1),
          threadgroup=(256, 1, 1),
          output_shapes={"out": a.shape},
          output_dtypes={"out": a.dtype},
          ensure_row_contiguous=False,
      )
      return outputs["out"]

  a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
  # make non-contiguous
  a = a[::2]
  b = exp_elementwise(a)
  assert mx.allclose(b, mx.exp(a))
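
Beyond dtype template parameters, ``template`` also accepts ``int`` and ``bool`` values, and Python scalars may be passed directly in ``inputs`` (they are converted to arrays and, for 0-dim values, bound by reference). The following is a hypothetical sketch illustrating this; the names ``affine``, ``inp``, ``scale``, ``shift`` and ``use_shift`` are our own and not part of the API:

.. code-block:: python

  def affine(a: mx.array, scale: float, shift: float):
      source = """
          uint elem = thread_position_in_grid.x;
          // `scale` and `shift` are scalar inputs bound by reference,
          // `use_shift` is a bool template parameter fixed at compile time.
          T val = static_cast<T>(inp[elem]) * scale;
          if (use_shift) {
              val += shift;
          }
          out[elem] = val;
      """

      kernel = mx.fast.metal_kernel(
          name="affine",
          source=source,
      )
      outputs = kernel(
          inputs={"inp": a, "scale": scale, "shift": shift},
          template={"T": mx.float32, "use_shift": True},
          grid=(a.size, 1, 1),
          threadgroup=(256, 1, 1),
          output_shapes={"out": a.shape},
          output_dtypes={"out": a.dtype},
      )
      return outputs["out"]

  a = mx.random.normal(shape=(8, 8))
  assert mx.allclose(affine(a, 2.0, 1.0), 2 * a + 1)
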


@@ -85,3 +85,4 @@ are the CPU and GPU.
dev/extensions
dev/metal_debugger
dev/custom_metal_kernels


@@ -12,3 +12,5 @@ Fast
layer_norm
rope
scaled_dot_product_attention
affine_quantize
metal_kernel


@@ -131,6 +131,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp


@@ -0,0 +1,84 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"
namespace mlx::core::fast {
void CustomKernel::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& s = stream();
for (auto& out : outputs) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
std::vector<array> copies;
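// If ensure_row_contiguous is set, copy any input that is not already row
// contiguous before launch; the temporary copies are released once the
// command buffer completes (see the completed handler below).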
auto check_input = [&copies, &s, this](const array& x) -> const array {
bool no_copy = x.flags().row_contiguous;
if (!ensure_row_contiguous_ || no_copy) {
return x;
} else {
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
copy_gpu(x, copies.back(), CopyType::General, s);
return copies.back();
}
};
std::vector<const array> checked_inputs;
for (const array& in : inputs) {
checked_inputs.push_back(check_input(in));
}
auto& d = metal::device(s.device);
const auto& lib_name = name_;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
lib = d.get_library(lib_name, metal::utils() + source_);
}
auto kernel = d.get_kernel(name_, lib);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
int index = 0;
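// Bind buffers in the same order as the generated signature: each input,
// followed by its optional shape/strides/ndim buffers, then the outputs.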
for (int i = 0; i < checked_inputs.size(); i++) {
const array& in = checked_inputs[i];
auto shape_info = shape_infos_[i];
compute_encoder.set_input_array(in, index);
index++;
if (in.ndim() > 0) {
int ndim = in.ndim();
if (shape_info.shape) {
set_vector_bytes(compute_encoder, in.shape(), ndim, index);
index++;
}
if (shape_info.strides) {
set_vector_bytes(compute_encoder, in.strides(), ndim, index);
index++;
}
if (shape_info.ndim) {
compute_encoder->setBytes(&ndim, sizeof(int), index);
index++;
}
}
}
for (array out : outputs) {
compute_encoder.set_output_array(out, index);
index++;
}
const auto [tx, ty, tz] = threadgroup_;
MTL::Size group_dims = MTL::Size(tx, ty, tz);
const auto [gx, gy, gz] = grid_;
MTL::Size grid_dims = MTL::Size(gx, gy, gz);
compute_encoder->dispatchThreads(grid_dims, group_dims);
if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
}
} // namespace mlx::core::fast


@@ -119,6 +119,7 @@ NO_GPU_MULTI(RMSNormVJP)
NO_GPU_MULTI(RoPE)
NO_GPU(ScaledDotProductAttention)
NO_GPU_MULTI(AffineQuantize)
NO_GPU_MULTI(CustomKernel)
} // namespace fast
} // namespace mlx::core


@@ -1,7 +1,10 @@
// Copyright © 2023-2024 Apple Inc.
#include <cassert>
#include <iostream>
#include <numeric>
#include <regex>
#include "mlx/backend/common/compiled.h"
#include "mlx/fast.h"
#include "mlx/fast_primitives.h"
#include "mlx/ops.h"
@@ -913,4 +916,271 @@ array affine_dequantize(
return fallback({w, scales, biases})[0];
}
void validate_output_shapes(
std::map<std::string, std::vector<int>> output_shapes,
std::map<std::string, Dtype> output_dtypes) {
// Make sure output shapes and dtypes have the same keys
bool validated = true;
if (output_shapes.size() == 0) {
throw std::invalid_argument(
"[metal_kernel] Must specify at least one output.");
}
if (output_shapes.size() != output_dtypes.size()) {
validated = false;
} else {
for (const auto& kv : output_shapes) {
if (output_dtypes.find(kv.first) == output_dtypes.end()) {
validated = false;
break;
}
}
}
if (!validated) {
throw std::invalid_argument(
"[metal_kernel] `output_shapes` and `output_dtypes` must have the same keys.");
}
}
void write_signature(
std::string func_name,
std::string& source,
std::map<std::string, array>& inputs,
std::map<std::string, std::vector<int>>& output_shapes,
std::map<std::string, Dtype>& output_dtypes,
std::optional<std::map<std::string, TemplateArg>> template_args,
std::vector<CustomKernelShapeInfo>& shape_infos,
std::ostringstream& kernel_source) {
// Auto-generate a function signature based on `template_args`
// and the dtype/shape of the arrays passed as `inputs`.
if (template_args && template_args.value().size() > 0) {
kernel_source << "template <";
int i = 0;
for (const auto& [name, arg] : template_args.value()) {
std::string param_type;
if (std::holds_alternative<int>(arg)) {
param_type = "int";
} else if (std::holds_alternative<bool>(arg)) {
param_type = "bool";
} else if (std::holds_alternative<Dtype>(arg)) {
param_type = "typename";
}
if (i > 0) {
kernel_source << ", ";
}
kernel_source << param_type << " " << name;
i++;
}
kernel_source << ">" << std::endl;
}
kernel_source << "[[kernel]] void " << func_name << "(" << std::endl;
// Metal attributes are automatically added to the arguments if present
const std::vector<std::pair<std::string, std::string>> metal_attributes = {
{"dispatch_quadgroups_per_threadgroup", "uint"},
{"dispatch_simdgroups_per_threadgroup", "uint"},
{"dispatch_threads_per_threadgroup", "uint3"},
{"grid_origin", "uint3"},
{"grid_size", "uint3"},
{"quadgroup_index_in_threadgroup", "uint"},
{"quadgroups_per_threadgroup", "uint"},
{"simdgroup_index_in_threadgroup", "uint"},
{"simdgroups_per_threadgroup", "uint"},
{"thread_execution_width", "uint"},
{"thread_index_in_quadgroup", "uint"},
{"thread_index_in_simdgroup", "uint"},
{"thread_index_in_threadgroup", "uint"},
{"thread_position_in_grid", "uint3"},
{"thread_position_in_threadgroup", "uint3"},
{"threadgroup_position_in_grid", "uint3"},
{"threadgroups_per_grid", "uint3"},
{"threads_per_grid", "uint3"},
{"threads_per_simdgroup", "uint"},
{"thread_per_threadgroup", "uint3"},
};
std::vector<std::pair<std::string, std::string>> attrs;
for (const auto& [attr, dtype] : metal_attributes) {
if (source.find(attr) != std::string::npos) {
attrs.push_back({attr, dtype});
}
}
int index = 0;
constexpr int max_constant_array_size = 8;
// Add inputs
for (const auto& [name, arr] : inputs) {
auto dtype = get_type_string(arr.dtype());
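// Already-evaluated inputs smaller than max_constant_array_size go in the
// constant address space; other inputs are bound as device buffers. 0-dim
// arrays are passed by reference rather than as pointers.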
bool is_constant =
arr.is_available() && arr.size() < max_constant_array_size;
std::string location = is_constant ? "constant" : "device";
std::string ref = arr.ndim() == 0 ? "&" : "*";
kernel_source << " const " << location << " " << dtype << ref << " "
<< name << " [[buffer(" << index << ")]]," << std::endl;
index++;
// Add input shape, strides and ndim if present in the source
CustomKernelShapeInfo shape_info;
if (arr.ndim() > 0) {
if (source.find(name + "_shape") != std::string::npos) {
kernel_source << " const constant int* " << name << "_shape [[buffer("
<< index << ")]]," << std::endl;
shape_info.shape = true;
index++;
}
if (source.find(name + "_strides") != std::string::npos) {
kernel_source << " const constant size_t* " << name
<< "_strides [[buffer(" << index << ")]]," << std::endl;
shape_info.strides = true;
index++;
}
if (source.find(name + "_ndim") != std::string::npos) {
kernel_source << " const constant int& " << name << "_ndim [[buffer("
<< index << ")]]," << std::endl;
shape_info.ndim = true;
index++;
}
}
shape_infos.push_back(shape_info);
}
// Add outputs
for (const auto& [name, dtype] : output_dtypes) {
kernel_source << " device " << get_type_string(dtype) << "* " << name
<< " [[buffer(" << index << ")]]";
if (index < inputs.size() + output_shapes.size() - 1 || attrs.size() > 0) {
kernel_source << "," << std::endl;
} else {
kernel_source << ") {" << std::endl;
}
index++;
}
// Add metal attributes e.g. `threadgroup_index_in_grid`
for (const auto& [attr, dtype] : attrs) {
kernel_source << " " << dtype << " " << attr << " [[" << attr << "]]";
if (index < attrs.size() - 1) {
kernel_source << "," << std::endl;
} else {
kernel_source << ") {" << std::endl;
}
}
kernel_source << source << std::endl;
kernel_source << "}" << std::endl;
}
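// Build the template instantiation string, e.g. {"T": float32} becomes "<float>".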
std::string write_template(std::map<std::string, TemplateArg>& template_args) {
std::ostringstream template_def;
template_def << "<";
int i = 0;
for (const auto& [name, arg] : template_args) {
if (i > 0) {
template_def << ", ";
}
if (std::holds_alternative<int>(arg)) {
template_def << std::get<int>(arg);
} else if (std::holds_alternative<bool>(arg)) {
template_def << std::get<bool>(arg);
} else if (std::holds_alternative<Dtype>(arg)) {
template_def << get_type_string(std::get<Dtype>(arg));
}
i++;
}
template_def << ">";
return template_def.str();
}
std::map<std::string, array> MetalKernel::operator()(
std::map<std::string, array>& inputs,
std::map<std::string, std::vector<int>> output_shapes,
std::map<std::string, Dtype> output_dtypes,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
std::optional<std::map<std::string, TemplateArg>> template_args,
bool verbose,
StreamOrDevice s_) {
validate_output_shapes(output_shapes, output_dtypes);
auto s = to_stream(s_);
if (s.device != Device::gpu) {
throw std::invalid_argument(
"[metal_kernel] MetalKernel only works on GPU.");
}
std::ostringstream kernel_source;
std::ostringstream func_name;
std::string template_def = "";
bool needs_template = template_args && template_args.value().size() > 0;
std::string hash_key = "";
if (needs_template) {
std::regex disallowed_chars("\\<|\\>|(, )");
template_def = write_template(template_args.value());
hash_key = std::regex_replace(template_def, disallowed_chars, "_");
hash_key.pop_back();
}
func_name << "custom_kernel_" << name_ << hash_key;
std::string kernel_name = func_name.str();
std::vector<CustomKernelShapeInfo> shape_infos;
write_signature(
func_name.str(),
source_,
inputs,
output_shapes,
output_dtypes,
template_args,
shape_infos,
kernel_source);
if (needs_template) {
template_def = func_name.str() + template_def;
kernel_source << std::endl
<< "template [[host_name(\"" << kernel_name
<< "\")]] [[kernel]] decltype(" << template_def << ") "
<< template_def << ";" << std::endl;
}
if (verbose) {
std::cout << "Generated source code for `" << name_ << "`:" << std::endl
<< "```" << std::endl
<< kernel_source.str() << std::endl
<< "```" << std::endl;
}
std::vector<array> in_arrs;
for (const auto& kv : inputs) {
in_arrs.push_back(kv.second);
}
std::vector<std::string> out_keys;
std::vector<std::vector<int>> out_shapes;
for (const auto& [name, shape] : output_shapes) {
out_keys.push_back(name);
out_shapes.push_back(shape);
}
std::vector<Dtype> out_dtypes;
for (const auto& kv : output_dtypes) {
out_dtypes.push_back(kv.second);
}
std::map<std::string, array> outputs;
auto outputs_vec = array::make_arrays(
out_shapes,
out_dtypes,
std::make_shared<CustomKernel>(
s,
kernel_name,
kernel_source.str(),
grid,
threadgroup,
shape_infos,
ensure_row_contiguous_),
in_arrs);
int i = 0;
for (const auto& key : out_keys) {
outputs.insert({key, outputs_vec[i]});
i++;
}
return outputs;
}
} // namespace mlx::core::fast


@@ -2,6 +2,7 @@
#pragma once
#include <map>
#include <optional>
#include "mlx/utils.h"
@@ -63,4 +64,32 @@ array affine_dequantize(
int bits = 4,
StreamOrDevice s = {});
typedef std::variant<int, bool, Dtype> TemplateArg;
class MetalKernel {
public:
MetalKernel(
const std::string& name,
const std::string& source,
bool ensure_row_contiguous)
: name_(name),
source_(source),
ensure_row_contiguous_(ensure_row_contiguous) {}
std::map<std::string, array> operator()(
std::map<std::string, array>& inputs,
std::map<std::string, std::vector<int>> output_shapes,
std::map<std::string, Dtype> output_dtypes,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
std::optional<std::map<std::string, TemplateArg>> template_args =
std::nullopt,
bool verbose = false,
StreamOrDevice s = {});
private:
std::string name_;
std::string source_;
bool ensure_row_contiguous_ = true;
};
} // namespace mlx::core::fast


@@ -242,4 +242,47 @@ class AffineQuantize : public Custom {
bool dequantize_;
};
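// Records which of the optional <name>_shape, <name>_strides and <name>_ndim
// buffers the generated kernel expects for each input array.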
struct CustomKernelShapeInfo {
bool shape = false;
bool strides = false;
bool ndim = false;
};
class CustomKernel : public Primitive {
public:
CustomKernel(
Stream stream,
std::string name,
std::string source,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
std::vector<CustomKernelShapeInfo> shape_infos,
bool ensure_row_contiguous)
: Primitive(stream),
source_(source),
name_(name),
grid_(grid),
threadgroup_(threadgroup),
shape_infos_(shape_infos),
ensure_row_contiguous_(ensure_row_contiguous) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
throw std::runtime_error("Custom Metal kernels only run on GPU.");
}
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
DEFINE_PRINT(CustomKernel);
private:
std::string source_;
std::string name_;
std::tuple<int, int, int> grid_;
std::tuple<int, int, int> threadgroup_;
std::vector<CustomKernelShapeInfo> shape_infos_;
bool ensure_row_contiguous_;
};
} // namespace mlx::core::fast


@@ -1,9 +1,14 @@
// Copyright © 2023-2024 Apple Inc.
#include <nanobind/nanobind.h>
#include <nanobind/stl/map.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/tuple.h>
#include <nanobind/stl/variant.h>
#include <nanobind/stl/vector.h>
#include "python/src/utils.h"
#include "mlx/fast.h"
#include "mlx/ops.h"
@@ -186,4 +191,136 @@ void init_fast(nb::module_& parent_module) {
Returns:
array: The quantized version of ``w``
)pbdoc");
nb::class_<fast::MetalKernel>(
m,
"metal_kernel",
R"pbdoc(
A jit-compiled custom Metal kernel defined from a source string.
)pbdoc")
.def(
nb::init<const std::string&, const std::string&, bool>(),
"name"_a,
"source"_a,
"ensure_row_contiguous"_a = true,
R"pbdoc(
Initialize a metal_kernel.
Args:
name (str): Name for the kernel.
source (str): Source code. This is the body of a function in Metal,
the function signature will be generated for you. The names of the inputs/outputs
are determined by the ``inputs`` and ``output_shapes``/``output_dtypes``
used when the kernel is called.
ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
before the kernel runs. Default: ``True``.
Returns:
Callable ``metal_kernel``.
.. code-block:: python
def exp_elementwise(a: mx.array):
source = """
uint elem = thread_position_in_grid.x;
T tmp = inp[elem];
out[elem] = metal::exp(tmp);
"""
kernel = mx.fast.metal_kernel(
name="myexp",
source=source
)
outputs = kernel(
inputs={"inp": a},
template={"T": mx.float32},
grid=(a.size, 1, 1),
threadgroup=(256, 1, 1),
output_shapes={"out": a.shape},
output_dtypes={"out": a.dtype},
verbose=True,
)
return outputs["out"]
a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
b = exp_elementwise(a)
assert mx.allclose(b, mx.exp(a))
)pbdoc")
.def(
"__call__",
[](fast::MetalKernel& kernel,
std::map<std::string, ScalarOrArray>& inputs_,
std::map<std::string, std::vector<int>>& output_shapes,
std::map<std::string, Dtype>& output_dtypes,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
std::optional<std::map<std::string, nb::handle>> template_args_,
bool verbose,
StreamOrDevice s) {
std::map<std::string, array> inputs;
for (const auto& [name, value] : inputs_) {
auto arr = to_array(value, std::nullopt);
inputs.insert({name, arr});
}
std::map<std::string, fast::TemplateArg> template_args;
if (template_args_) {
for (const auto& [name, value] : template_args_.value()) {
// Handle bool, int and dtype template args
if (nb::isinstance<bool>(value)) {
bool bool_val = nb::cast<bool>(value);
template_args.insert({name, bool_val});
} else if (nb::isinstance<int>(value)) {
int int_val = nb::cast<int>(value);
template_args.insert({name, int_val});
} else if (nb::isinstance<Dtype>(value)) {
Dtype dtype = nb::cast<Dtype>(value);
template_args.insert({name, dtype});
} else {
throw std::invalid_argument(
"[[metal_kernel]] Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`.");
}
}
}
return kernel(
inputs,
output_shapes,
output_dtypes,
grid,
threadgroup,
template_args,
verbose,
s);
},
nb::kw_only(),
"inputs"_a,
"output_shapes"_a,
"output_dtypes"_a,
"grid"_a,
"threadgroup"_a,
"template"_a = nb::none(),
"verbose"_a = false,
"stream"_a = nb::none(),
nb::sig(
"def __call__(self, *, inputs: Mapping[str, Union[scalar, array]], output_shapes: Mapping[str, Sequence[int]], output_dtypes: Mapping[str, Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[Mapping[str, Union[bool, int, Dtype]]] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)"),
R"pbdoc(
Run the kernel.
Args:
inputs (Mapping[str, array]): Inputs. These will be added to the function signature and passed to the Metal kernel.
The keys will be the names of the arguments to the kernel.
output_shapes (Mapping[str, Sequence[int]]): Output shapes. A dict mapping
output variable names to shapes. These will be added to the function signature.
output_dtypes (Mapping[str, Dtype]): Output dtypes. A dict mapping output variable
names to dtypes. Must have the same keys as ``output_shapes``.
grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
template (Mapping[str, Union[bool, int, Dtype]], optional): Template arguments.
These will be added as template arguments to the kernel definition.
verbose (bool, optional): Whether to print the full generated source code of the kernel
when it is run.
stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.
Returns:
dict[str, array]: Dictionary of output arrays based on ``output_shapes``/``output_dtypes``.
)pbdoc");
}


@@ -325,9 +325,9 @@ void init_linalg(nb::module_& parent_module) {
nb::sig(
"def cholesky_inv(L: array, upper: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Compute the inverse of a real symmetric positive semi-definite matrix using it's Cholesky decomposition L.
Compute the inverse of a real symmetric positive semi-definite matrix using its Cholesky decomposition.
Let A be a real symmetric positive semi-definite matrix and L its Cholesky definition such that:
Let :math:`\mathbf{A}` be a real symmetric positive semi-definite matrix and :math:`\mathbf{L}` its Cholesky decomposition such that:
.. math::
@@ -339,7 +339,7 @@ void init_linalg(nb::module_& parent_module) {
This function supports arrays with at least 2 dimensions. When the input
has more than two dimensions, the Cholesky inverse is computed for each matrix
in the last two dimensions of ``L``.
in the last two dimensions of :math:`\mathbf{L}`.
If the input matrix is not a triangular matrix behaviour is undefined.
@@ -351,6 +351,6 @@ void init_linalg(nb::module_& parent_module) {
in which case the default stream of the default device is used.
Returns:
array: :math:`A^{-1}` where :math:`\mathbf{A} = \mathbf{L}\mathbf{L}^T`.
array: :math:`\mathbf{A^{-1}}` where :math:`\mathbf{A} = \mathbf{L}\mathbf{L}^T`.
)pbdoc");
}


@@ -548,6 +548,104 @@ class TestFast(mlx_tests.MLXTestCase):
        )
        self.assertTrue(mx.allclose(w, w_p))

    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
    def test_custom_kernel_basic(self):
        mx.random.seed(7)
        a = mx.random.normal(shape=(3, 6))
        kernel = mx.fast.metal_kernel(
            name="basic",
            source="""
                uint elem = thread_position_in_grid.x;
                out1[elem] = a[elem];
            """,
        )
        out = kernel(
            inputs={"a": a},
            grid=(4, 1, 1),
            threadgroup=(2, 1, 1),
            output_shapes={"out1": (2, 2)},
            output_dtypes={"out1": mx.float32},
            stream=mx.gpu,
        )
        mx.allclose(out["out1"], a[:2, :2])

    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
    def test_custom_kernel_args(self):
        mx.random.seed(7)
        a = mx.random.normal(shape=(3, 6))
        c = mx.random.normal(shape=(2, 2)).astype(mx.bfloat16)
        kernel = mx.fast.metal_kernel(
            name="arg_test",
            source="""
                uint elem = thread_position_in_grid.x;
                T tmp = a[0];
                if (e) {
                    out1[elem] = a[1] + b[2] + c[3] + d + f;
                } else {
                    out1[elem] = 1;
                }
                out2[elem] = a[1] + b[2] + c[1] - d;
            """,
        )
        out = kernel(
            inputs={
                "a": a,
                "b": mx.array([3, 4, 5]),
                "c": c,
                "d": 7.3,
            },
            template={
                "e": True,
                "f": 3,
                "T": mx.float16,
            },
            grid=(6, 1, 1),
            threadgroup=(2, 1, 1),
            output_shapes={"out1": (2, 2), "out2": (3, 2)},
            output_dtypes={"out1": mx.float32, "out2": mx.int32},
            stream=mx.gpu,
        )
        self.assertTrue(mx.allclose(out["out1"], mx.full((2, 2), 14.0484)))
        self.assertTrue(mx.allclose(out["out2"], mx.full((3, 2), -2, dtype=mx.int32)))

    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
    def test_custom_kernel_strides(self):
        mx.random.seed(7)
        a = mx.random.normal(shape=(3, 6))
        source = """
            uint elem = thread_position_in_grid.x;
            uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
            T tmp = inp[loc];
            out[elem] = metal::exp(tmp);
        """
        source_contig = """
            uint elem = thread_position_in_grid.x;
            T tmp = inp[elem];
            out[elem] = metal::exp(tmp);
        """

        # non contiguous
        a = mx.tile(a[::2], [4, 1])

        for contig in [True, False]:
            kernel = mx.fast.metal_kernel(
                name="myexp" + str(contig),
                source=source_contig if contig else source,
                ensure_row_contiguous=contig,
            )
            outputs = kernel(
                inputs={"inp": a},
                template={"T": mx.float32},
                grid=(a.size, 1, 1),
                threadgroup=(256, 1, 1),
                output_shapes={"out": a.shape},
                output_dtypes={"out": a.dtype},
                stream=mx.gpu,
            )
            self.assertTrue(mx.allclose(mx.exp(a), outputs["out"]))


if __name__ == "__main__":
    unittest.main()