Simplifications for MLX C (#1396)

* simplifications for MLX C

* use vectors instead of map

* update examples
This commit is contained in:
Awni Hannun 2024-09-06 19:16:50 -07:00 committed by GitHub
parent 7cca1727af
commit ba3e913c7a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 334 additions and 331 deletions

View File

@ -19,17 +19,19 @@ Let's write a custom kernel that computes ``exp`` elementwise:
kernel = mx.fast.metal_kernel( kernel = mx.fast.metal_kernel(
name="myexp", name="myexp",
input_names=["inp"],
output_names=["out"],
source=source, source=source,
) )
outputs = kernel( outputs = kernel(
inputs={"inp": a}, inputs=[a],
template={"T": mx.float32}, template=[("T", mx.float32)],
grid=(a.size, 1, 1), grid=(a.size, 1, 1),
threadgroup=(256, 1, 1), threadgroup=(256, 1, 1),
output_shapes={"out": a.shape}, output_shapes=[a.shape],
output_dtypes={"out": a.dtype}, output_dtypes=[a.dtype],
) )
return outputs["out"] return outputs[0]
a = mx.random.normal(shape=(4, 16)).astype(mx.float16) a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
b = exp_elementwise(a) b = exp_elementwise(a)
@ -40,16 +42,16 @@ Let's write a custom kernel that computes ``exp`` elementwise:
The full function signature will be generated using: The full function signature will be generated using:
* The keys and shapes/dtypes of ``inputs`` * The shapes/dtypes of ``inputs``
In the above, ``a`` is an ``mx.array`` of type ``mx.float16`` and we pass it with the key ``inp`` In the above, ``a`` is an ``mx.array`` of type ``mx.float16`` and we pass it with the key ``inp``
so we will add ``const device float16_t* inp`` to the signature. so we will add ``const device float16_t* inp`` to the signature.
``inp_shape``, ``inp_strides`` and ``inp_ndim`` are also added for convenience if they are present ``inp_shape``, ``inp_strides`` and ``inp_ndim`` are also added for convenience if they are present
in ``source``. in ``source``.
* The keys and values of ``output_shapes`` and ``output_dtypes`` * The list of ``output_dtypes``
In the above, ``out`` is an ``mx.array`` of type ``mx.float16`` In the above, ``out`` is an ``mx.array`` of type ``mx.float16``
so we add ``device float16_t* out``. so we add ``device float16_t* out``.
* Template parameters passed using ``template`` * Template parameters passed using ``template``
In the above, ``template={"T": mx.float32}`` adds a template of ``template <typename T>`` to the function In the above, ``template=[("T", mx.float32)]`` adds a template of ``template <typename T>`` to the function
and instantiates the template with ``custom_kernel_myexp_float<float>``. and instantiates the template with ``custom_kernel_myexp_float<float>``.
Template parameters can be ``mx.core.Dtype``, ``int`` or ``bool``. Template parameters can be ``mx.core.Dtype``, ``int`` or ``bool``.
* Metal attributes used in ``source`` such as ``[[thread_position_in_grid]]`` * Metal attributes used in ``source`` such as ``[[thread_position_in_grid]]``
@ -104,18 +106,20 @@ Let's convert ``myexp`` above to support arbitrarily strided arrays without rely
kernel = mx.fast.metal_kernel( kernel = mx.fast.metal_kernel(
name="myexp_strided", name="myexp_strided",
input_names=["inp"],
output_names=["out"],
source=source source=source
) )
outputs = kernel( outputs = kernel(
inputs={"inp": a}, inputs=[a],
template={"T": mx.float32}, template=[("T", mx.float32)],
grid=(a.size, 1, 1), grid=(a.size, 1, 1),
threadgroup=(256, 1, 1), threadgroup=(256, 1, 1),
output_shapes={"out": a.shape}, output_shapes=[a.shape],
output_dtypes={"out": a.dtype}, output_dtypes=[a.dtype],
ensure_row_contiguous=False, ensure_row_contiguous=False,
) )
return outputs["out"] return outputs[0]
a = mx.random.normal(shape=(4, 16)).astype(mx.float16) a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
# make non-contiguous # make non-contiguous
@ -243,17 +247,19 @@ First we'll implement the forward pass as a fused kernel:
""" """
kernel = mx.fast.metal_kernel( kernel = mx.fast.metal_kernel(
name="grid_sample", name="grid_sample",
input_names=["x", "grid"],
output_names=["out"],
source=source, source=source,
) )
outputs = kernel( outputs = kernel(
inputs={"x": x, "grid": grid}, inputs=[x, grid],
template={"T": x.dtype}, template=[("T", x.dtype)],
output_shapes={"out": out_shape}, output_shapes=[out_shape],
output_dtypes={"out": x.dtype}, output_dtypes=[x.dtype],
grid=(np.prod(out_shape), 1, 1), grid=(np.prod(out_shape), 1, 1),
threadgroup=(256, 1, 1), threadgroup=(256, 1, 1),
) )
return outputs["out"] return outputs[0]
For a reasonably sized input such as: For a reasonably sized input such as:
@ -389,6 +395,8 @@ We can then implement the backwards pass as follows:
""" """
kernel = mx.fast.metal_kernel( kernel = mx.fast.metal_kernel(
name="grid_sample_grad", name="grid_sample_grad",
input_names=["x", "grid", "cotangent"],
output_names=["x_grad", "grid_grad"],
source=source, source=source,
atomic_outputs=True, atomic_outputs=True,
) )
@ -398,15 +406,15 @@ We can then implement the backwards pass as follows:
C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
grid_size = B * gN * gM * C_padded grid_size = B * gN * gM * C_padded
outputs = kernel( outputs = kernel(
inputs={"x": x, "grid": grid, "cotangent": cotangent}, inputs=[x, grid, cotangent],
template={"T": x.dtype}, template=[("T", x.dtype)],
output_shapes={"x_grad": x.shape, "grid_grad": grid.shape}, output_shapes=[x.shape, grid.shape],
output_dtypes={"x_grad": x.dtype, "grid_grad": x.dtype}, output_dtypes=[x.dtype, x.dtype],
grid=(grid_size, 1, 1), grid=(grid_size, 1, 1),
threadgroup=(256, 1, 1), threadgroup=(256, 1, 1),
init_value=0, init_value=0,
) )
return outputs["x_grad"], outputs["grid_grad"] return outputs[0], outputs[1]
There's an even larger speed up for the vjp: There's an even larger speed up for the vjp:

View File

@ -515,7 +515,7 @@ array scaled_dot_product_attention(
const array& values, const array& values,
const float scale, const float scale,
const std::optional<array>& mask, const std::optional<array>& mask,
const std::optional<int>& memory_efficient_threshold, const std::optional<int> memory_efficient_threshold,
StreamOrDevice s) { StreamOrDevice s) {
for (const auto& tensor : {queries, keys, values}) { for (const auto& tensor : {queries, keys, values}) {
if (tensor.ndim() != 4) { if (tensor.ndim() != 4) {
@ -916,47 +916,23 @@ array affine_dequantize(
return fallback({w, scales, biases})[0]; return fallback({w, scales, biases})[0];
} }
void validate_output_shapes(
std::map<std::string, std::vector<int>> output_shapes,
std::map<std::string, Dtype> output_dtypes) {
// Make sure output shapes and dtypes have the same keys
bool validated = true;
if (output_shapes.size() == 0) {
throw std::invalid_argument(
"[metal_kernel] Must specify at least one output.");
}
if (output_shapes.size() != output_dtypes.size()) {
validated = false;
} else {
for (const auto& kv : output_shapes) {
if (output_dtypes.find(kv.first) == output_dtypes.end()) {
validated = false;
break;
}
}
}
if (!validated) {
throw std::invalid_argument(
"[metal_kernel] `output_shapes` and `output_dtypes` must have the same keys.");
}
}
void write_signature( void write_signature(
std::string func_name, std::string func_name,
std::string& source, const std::string& source,
std::map<std::string, array>& inputs, const std::vector<std::string>& input_names,
std::map<std::string, std::vector<int>>& output_shapes, const std::vector<array>& inputs,
std::map<std::string, Dtype>& output_dtypes, const std::vector<std::string>& output_names,
std::optional<std::map<std::string, TemplateArg>> template_args, const std::vector<Dtype>& output_dtypes,
const std::vector<std::pair<std::string, TemplateArg>>& template_args,
std::vector<CustomKernelShapeInfo>& shape_infos, std::vector<CustomKernelShapeInfo>& shape_infos,
bool atomic_outputs, bool atomic_outputs,
std::ostringstream& kernel_source) { std::ostringstream& kernel_source) {
// Auto-generate a function signature based on `template_args` // Auto-generate a function signature based on `template_args`
// and the dtype/shape of the arrays passed as `inputs`. // and the dtype/shape of the arrays passed as `inputs`.
if (template_args && template_args.value().size() > 0) { if (!template_args.empty()) {
kernel_source << "template <"; kernel_source << "template <";
int i = 0; int i = 0;
for (const auto& [name, arg] : template_args.value()) { for (const auto& [name, arg] : template_args) {
std::string param_type; std::string param_type;
if (std::holds_alternative<int>(arg)) { if (std::holds_alternative<int>(arg)) {
param_type = "int"; param_type = "int";
@ -1008,7 +984,9 @@ void write_signature(
int index = 0; int index = 0;
constexpr int max_constant_array_size = 8; constexpr int max_constant_array_size = 8;
// Add inputs // Add inputs
for (const auto& [name, arr] : inputs) { for (int i = 0; i < inputs.size(); ++i) {
const auto& name = input_names[i];
const auto& arr = inputs[i];
auto dtype = get_type_string(arr.dtype()); auto dtype = get_type_string(arr.dtype());
bool is_constant = bool is_constant =
arr.is_available() && arr.size() < max_constant_array_size; arr.is_available() && arr.size() < max_constant_array_size;
@ -1042,7 +1020,9 @@ void write_signature(
shape_infos.push_back(shape_info); shape_infos.push_back(shape_info);
} }
// Add outputs // Add outputs
for (const auto& [name, dtype] : output_dtypes) { for (int i = 0; i < output_names.size(); ++i) {
const auto& name = output_names[i];
const auto& dtype = output_dtypes[i];
kernel_source << " device "; kernel_source << " device ";
auto type_string = get_type_string(dtype); auto type_string = get_type_string(dtype);
if (atomic_outputs) { if (atomic_outputs) {
@ -1051,7 +1031,7 @@ void write_signature(
kernel_source << type_string; kernel_source << type_string;
} }
kernel_source << "* " << name << " [[buffer(" << index << ")]]"; kernel_source << "* " << name << " [[buffer(" << index << ")]]";
if (index < inputs.size() + output_shapes.size() - 1 || attrs.size() > 0) { if (index < inputs.size() + output_names.size() - 1 || attrs.size() > 0) {
kernel_source << "," << std::endl; kernel_source << "," << std::endl;
} else { } else {
kernel_source << ") {" << std::endl; kernel_source << ") {" << std::endl;
@ -1073,7 +1053,8 @@ void write_signature(
kernel_source << "}" << std::endl; kernel_source << "}" << std::endl;
} }
std::string write_template(std::map<std::string, TemplateArg>& template_args) { std::string write_template(
const std::vector<std::pair<std::string, TemplateArg>>& template_args) {
std::ostringstream template_def; std::ostringstream template_def;
template_def << "<"; template_def << "<";
int i = 0; int i = 0;
@ -1094,107 +1075,115 @@ std::string write_template(std::map<std::string, TemplateArg>& template_args) {
return template_def.str(); return template_def.str();
} }
std::map<std::string, array> MetalKernel::operator()( MetalKernelFunction metal_kernel(
std::map<std::string, array>& inputs, const std::string& name,
std::map<std::string, std::vector<int>> output_shapes, const std::vector<std::string>& input_names,
std::map<std::string, Dtype> output_dtypes, const std::vector<std::string>& output_names,
std::tuple<int, int, int> grid, const std::string& source,
std::tuple<int, int, int> threadgroup, const std::string& header /* = "" */,
std::optional<std::map<std::string, TemplateArg>> template_args, bool ensure_row_contiguous /* = true */,
std::optional<float> init_value, bool atomic_outputs /* = false */) {
bool verbose, if (output_names.empty()) {
StreamOrDevice s_) {
validate_output_shapes(output_shapes, output_dtypes);
auto s = to_stream(s_);
if (s.device != Device::gpu) {
throw std::invalid_argument( throw std::invalid_argument(
"[metal_kernel] MetalKernel only works on GPU."); "[metal_kernel] Must specify at least one output.");
} }
std::ostringstream func_name; return [=](const std::vector<array>& inputs,
const std::vector<std::vector<int>>& output_shapes,
const std::vector<Dtype>& output_dtypes,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
const std::vector<std::pair<std::string, TemplateArg>>&
template_args = {},
std::optional<float> init_value = std::nullopt,
bool verbose = false,
StreamOrDevice s_ = {}) {
if (inputs.size() != input_names.size()) {
std::ostringstream msg;
msg << "[metal_kernel] Expected `inputs` to have size "
<< input_names.size() << " but got size " << inputs.size() << "."
<< std::endl;
throw std::invalid_argument(msg.str());
}
if (output_shapes.size() != output_names.size()) {
std::ostringstream msg;
msg << "[metal_kernel] Expected `output_shapes` to have size "
<< output_names.size() << " but got size " << output_shapes.size()
<< "." << std::endl;
throw std::invalid_argument(msg.str());
}
if (output_dtypes.size() != output_names.size()) {
std::ostringstream msg;
msg << "[metal_kernel] Expected `output_dtypes` to have size "
<< output_names.size() << " but got size " << output_dtypes.size()
<< "." << std::endl;
throw std::invalid_argument(msg.str());
}
std::string template_def = ""; auto s = to_stream(s_);
bool needs_template = template_args && template_args.value().size() > 0; if (s.device != Device::gpu) {
std::string hash_key = ""; throw std::invalid_argument("[metal_kernel] Only supports the GPU.");
if (needs_template) { }
std::regex disallowed_chars("\\<|\\>|(, )");
template_def = write_template(template_args.value());
hash_key = std::regex_replace(template_def, disallowed_chars, "_");
hash_key.pop_back();
}
func_name << "custom_kernel_" << name_ << hash_key; std::ostringstream func_name;
std::string kernel_name = func_name.str();
std::ostringstream kernel_source; std::string template_def = "";
kernel_source << header_ << std::endl; std::string hash_key = "";
if (!template_args.empty()) {
std::regex disallowed_chars("\\<|\\>|(, )");
template_def = write_template(template_args);
hash_key = std::regex_replace(template_def, disallowed_chars, "_");
hash_key.pop_back();
}
std::vector<CustomKernelShapeInfo> shape_infos; func_name << "custom_kernel_" << name << hash_key;
write_signature( std::string kernel_name = func_name.str();
func_name.str(),
source_,
inputs,
output_shapes,
output_dtypes,
template_args,
shape_infos,
atomic_outputs_,
kernel_source);
if (needs_template) { std::ostringstream kernel_source;
template_def = func_name.str() + template_def; kernel_source << header << std::endl;
kernel_source << std::endl
<< "template [[host_name(\"" << kernel_name
<< "\")]] [[kernel]] decltype(" << template_def << ") "
<< template_def << ";" << std::endl;
}
if (verbose) { std::vector<CustomKernelShapeInfo> shape_infos;
std::cout << "Generated source code for `" << name_ << "`:" << std::endl write_signature(
<< "```" << std::endl func_name.str(),
<< kernel_source.str() << std::endl source,
<< "```" << std::endl; input_names,
} inputs,
output_names,
output_dtypes,
template_args,
shape_infos,
atomic_outputs,
kernel_source);
std::vector<array> in_arrs; if (!template_args.empty()) {
for (const auto& kv : inputs) { template_def = func_name.str() + template_def;
in_arrs.push_back(kv.second); kernel_source << std::endl
} << "template [[host_name(\"" << kernel_name
<< "\")]] [[kernel]] decltype(" << template_def << ") "
<< template_def << ";" << std::endl;
}
std::vector<std::string> out_keys; if (verbose) {
std::vector<std::vector<int>> out_shapes; std::cout << "Generated source code for `" << name << "`:" << std::endl
for (const auto& [name, shape] : output_shapes) { << "```" << std::endl
out_keys.push_back(name); << kernel_source.str() << std::endl
out_shapes.push_back(shape); << "```" << std::endl;
} }
std::vector<Dtype> out_dtypes; return array::make_arrays(
for (const auto& kv : output_dtypes) { output_shapes,
out_dtypes.push_back(kv.second); output_dtypes,
} std::make_shared<CustomKernel>(
s,
std::map<std::string, array> outputs; kernel_name,
auto outputs_vec = array::make_arrays( kernel_source.str(),
out_shapes, grid,
out_dtypes, threadgroup,
std::make_shared<CustomKernel>( shape_infos,
s, ensure_row_contiguous,
kernel_name, init_value),
kernel_source.str(), inputs);
grid, };
threadgroup,
shape_infos,
ensure_row_contiguous_,
init_value),
in_arrs);
int i = 0;
for (const auto& key : out_keys) {
outputs.insert({key, outputs_vec[i]});
i++;
}
return outputs;
} }
} // namespace mlx::core::fast } // namespace mlx::core::fast

View File

@ -2,7 +2,6 @@
#pragma once #pragma once
#include <map>
#include <optional> #include <optional>
#include "mlx/utils.h" #include "mlx/utils.h"
@ -39,7 +38,7 @@ array scaled_dot_product_attention(
const array& values, const array& values,
const float scale, const float scale,
const std::optional<array>& mask = std::nullopt, const std::optional<array>& mask = std::nullopt,
const std::optional<int>& memory_efficient_threshold = std::nullopt, const std::optional<int> memory_efficient_threshold = std::nullopt,
StreamOrDevice s = {}); StreamOrDevice s = {});
std::tuple<array, array, array> affine_quantize( std::tuple<array, array, array> affine_quantize(
@ -66,37 +65,25 @@ array affine_dequantize(
typedef std::variant<int, bool, Dtype> TemplateArg; typedef std::variant<int, bool, Dtype> TemplateArg;
class MetalKernel { typedef std::function<std::vector<array>(
public: const std::vector<array>&,
MetalKernel( const std::vector<std::vector<int>>&,
const std::string& name, const std::vector<Dtype>&,
const std::string& source, std::tuple<int, int, int>,
const std::string& header = "", std::tuple<int, int, int>,
bool ensure_row_contiguous = true, std::vector<std::pair<std::string, TemplateArg>>,
bool atomic_outputs = false) std::optional<float>,
: name_(name), bool,
source_(source), StreamOrDevice)>
header_(header), MetalKernelFunction;
ensure_row_contiguous_(ensure_row_contiguous),
atomic_outputs_(atomic_outputs) {}
std::map<std::string, array> operator()( MetalKernelFunction metal_kernel(
std::map<std::string, array>& inputs, const std::string& name,
std::map<std::string, std::vector<int>> output_shapes, const std::vector<std::string>& input_names,
std::map<std::string, Dtype> output_dtypes, const std::vector<std::string>& output_names,
std::tuple<int, int, int> grid, const std::string& source,
std::tuple<int, int, int> threadgroup, const std::string& header = "",
std::optional<std::map<std::string, TemplateArg>> template_args = bool ensure_row_contiguous = true,
std::nullopt, bool atomic_outputs = false);
std::optional<float> init_value = std::nullopt,
bool verbose = false,
StreamOrDevice s = {});
private:
std::string name_;
std::string source_;
std::string header_;
bool ensure_row_contiguous_;
bool atomic_outputs_;
};
} // namespace mlx::core::fast } // namespace mlx::core::fast

View File

@ -1425,8 +1425,8 @@ array where(
array nan_to_num( array nan_to_num(
const array& a, const array& a,
float nan /* = 0.0f */, float nan /* = 0.0f */,
const std::optional<float>& posinf_ /* = std::nullopt */, const std::optional<float> posinf_ /* = std::nullopt */,
const std::optional<float>& neginf_ /* = std::nullopt */, const std::optional<float> neginf_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
Dtype dtype = a.dtype(); Dtype dtype = a.dtype();
if (!issubdtype(dtype, inexact)) { if (!issubdtype(dtype, inexact)) {

View File

@ -416,8 +416,8 @@ array where(
array nan_to_num( array nan_to_num(
const array& a, const array& a,
float nan = 0.0f, float nan = 0.0f,
const std::optional<float>& posinf = std::nullopt, const std::optional<float> posinf = std::nullopt,
const std::optional<float>& neginf = std::nullopt, const std::optional<float> neginf = std::nullopt,
StreamOrDevice s = {}); StreamOrDevice s = {});
/** True if all elements in the array are true (or non-zero). **/ /** True if all elements in the array are true (or non-zero). **/

View File

@ -1,8 +1,8 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <nanobind/nanobind.h> #include <nanobind/nanobind.h>
#include <nanobind/stl/map.h>
#include <nanobind/stl/optional.h> #include <nanobind/stl/optional.h>
#include <nanobind/stl/pair.h>
#include <nanobind/stl/string.h> #include <nanobind/stl/string.h>
#include <nanobind/stl/tuple.h> #include <nanobind/stl/tuple.h>
#include <nanobind/stl/variant.h> #include <nanobind/stl/variant.h>
@ -193,39 +193,130 @@ void init_fast(nb::module_& parent_module) {
array: The quantized version of ``w`` array: The quantized version of ``w``
)pbdoc"); )pbdoc");
nb::class_<fast::MetalKernel>( m.def(
m,
"metal_kernel", "metal_kernel",
[](const std::string& name,
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
const std::string& source,
const std::string& header,
bool ensure_row_contiguous,
bool atomic_outputs) {
auto kernel = fast::metal_kernel(
name,
input_names,
output_names,
source,
header,
ensure_row_contiguous,
atomic_outputs);
return nb::cpp_function(
[kernel = std::move(kernel)](
const std::vector<ScalarOrArray>& inputs_,
const std::vector<std::vector<int>>& output_shapes,
const std::vector<Dtype>& output_dtypes,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
const std::optional<
std::vector<std::pair<std::string, nb::object>>>&
template_args_ = std::nullopt,
std::optional<float> init_value = std::nullopt,
bool verbose = false,
StreamOrDevice s = {}) {
std::vector<array> inputs;
for (const auto& value : inputs_) {
inputs.push_back(to_array(value, std::nullopt));
}
std::vector<std::pair<std::string, fast::TemplateArg>>
template_args;
if (template_args_) {
for (const auto& [name, value] : template_args_.value()) {
// Handle bool, int and dtype template args
if (nb::isinstance<bool>(value)) {
bool bool_val = nb::cast<bool>(value);
template_args.emplace_back(name, bool_val);
} else if (nb::isinstance<int>(value)) {
int int_val = nb::cast<int>(value);
template_args.emplace_back(name, int_val);
} else if (nb::isinstance<Dtype>(value)) {
Dtype dtype = nb::cast<Dtype>(value);
template_args.emplace_back(name, dtype);
} else {
throw std::invalid_argument(
"[[metal_kernel]] Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`.");
}
}
}
return kernel(
inputs,
output_shapes,
output_dtypes,
grid,
threadgroup,
template_args,
init_value,
verbose,
s);
},
nb::kw_only(),
"inputs"_a,
"output_shapes"_a,
"output_dtypes"_a,
"grid"_a,
"threadgroup"_a,
"template"_a = nb::none(),
"init_value"_a = nb::none(),
"verbose"_a = false,
"stream"_a = nb::none(),
nb::sig(
"def __call__(self, *, inputs: List[Union[scalar, array]], output_shapes: List[Sequence[int]], output_dtypes: List[Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[List[Tuple[str, Union[bool, int, Dtype]]]] = None, init_value: Optional[float] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)"),
R"pbdoc(
Run the kernel.
Args:
inputs (List[array]): The inputs passed to the Metal kernel.
output_shapes (List[Sequence[int]]): The list of shapes for each output in ``output_names``.
output_dtypes (List[Dtype]): The list of data types for each output in ``output_names``.
grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
template (List[Tuple[str, Union[bool, int, Dtype]]], optional): Template arguments.
These will be added as template arguments to the kernel definition. Default: ``None``.
init_value (float, optional): Optional value to use to initialize all of the output arrays.
By default, output arrays are uninitialized. Default: ``None``.
verbose (bool, optional): Whether to print the full generated source code of the kernel
when it is run. Default: ``False``.
stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.
Returns:
List[array]: The list of output arrays.
)pbdoc");
},
"name"_a,
"input_names"_a,
"output_names"_a,
"source"_a,
"header"_a = "",
"ensure_row_contiguous"_a = true,
"atomic_outputs"_a = false,
R"pbdoc( R"pbdoc(
A jit-compiled custom Metal kernel defined from a source string. A jit-compiled custom Metal kernel defined from a source string.
)pbdoc")
.def(
nb::init<
const std::string&,
const std::string&,
const std::string&,
bool,
bool>(),
"name"_a,
"source"_a,
"header"_a = "",
"ensure_row_contiguous"_a = true,
"atomic_outputs"_a = false,
R"pbdoc(
Initialize a metal_kernel.
Args: Args:
name (str): Name for the kernel. name (str): Name for the kernel.
source (str): Source code. This is the body of a function in Metal, input_names (List[str]): The parameter names of the inputs in the
the function signature will be generated for you. The names of the inputs/outputs function signature.
are determined by the ``inputs`` and ``output_shapes``/``output_dtypes`` output_names (List[str]): The parameter names of the outputs in the
used when the kernel is called. function signature.
header (str): Header source code to include before the main function. source (str): Source code. This is the body of a function in Metal,
Useful for helper functions or includes that should live outside of the main function body. the function signature will be automatically generated.
ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous header (str): Header source code to include before the main function.
before the kernel runs. Default: ``True``. Useful for helper functions or includes that should live outside of
atomic_outputs (bool): Whether to use atomic outputs in the function signature the main function body.
e.g. ``device atomic<float>``. Default: ``False``. ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
before the kernel runs. Default: ``True``.
atomic_outputs (bool): Whether to use atomic outputs in the function signature
e.g. ``device atomic<float>``. Default: ``False``.
Returns: Returns:
Callable ``metal_kernel``. Callable ``metal_kernel``.
@ -242,103 +333,23 @@ void init_fast(nb::module_& parent_module) {
kernel = mx.fast.metal_kernel( kernel = mx.fast.metal_kernel(
name="myexp", name="myexp",
input_names=["inp"],
output_names=["out"],
source=source source=source
) )
outputs = kernel( outputs = kernel(
inputs={"inp": a}, inputs=[a],
template={"T": mx.float32}, template=[("T", mx.float32)],
grid=(a.size, 1, 1), grid=(a.size, 1, 1),
threadgroup=(256, 1, 1), threadgroup=(256, 1, 1),
output_shapes={"out": a.shape}, output_shapes=[a.shape],
output_dtypes={"out": a.dtype}, output_dtypes=[a.dtype],
verbose=True, verbose=True,
) )
return outputs["out"] return outputs[0]
a = mx.random.normal(shape=(4, 16)).astype(mx.float16) a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
b = exp_elementwise(a) b = exp_elementwise(a)
assert mx.allclose(b, mx.exp(a)) assert mx.allclose(b, mx.exp(a))
)pbdoc") )pbdoc");
.def(
"__call__",
[](fast::MetalKernel& kernel,
std::map<std::string, ScalarOrArray>& inputs_,
std::map<std::string, std::vector<int>>& output_shapes,
std::map<std::string, Dtype>& output_dtypes,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
std::optional<std::map<std::string, nb::handle>> template_args_,
std::optional<float> init_value,
bool verbose,
StreamOrDevice s) {
std::map<std::string, array> inputs;
for (const auto& [name, value] : inputs_) {
auto arr = to_array(value, std::nullopt);
inputs.insert({name, arr});
}
std::map<std::string, fast::TemplateArg> template_args;
if (template_args_) {
for (const auto& [name, value] : template_args_.value()) {
// Handle bool, int and dtype template args
if (nb::isinstance<bool>(value)) {
bool bool_val = nb::cast<bool>(value);
template_args.insert({name, bool_val});
} else if (nb::isinstance<int>(value)) {
int int_val = nb::cast<int>(value);
template_args.insert({name, int_val});
} else if (nb::isinstance<Dtype>(value)) {
Dtype dtype = nb::cast<Dtype>(value);
template_args.insert({name, dtype});
} else {
throw std::invalid_argument(
"[[metal_kernel]] Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`.");
}
}
}
return kernel(
inputs,
output_shapes,
output_dtypes,
grid,
threadgroup,
template_args,
init_value,
verbose,
s);
},
nb::kw_only(),
"inputs"_a,
"output_shapes"_a,
"output_dtypes"_a,
"grid"_a,
"threadgroup"_a,
"template"_a = nb::none(),
"init_value"_a = nb::none(),
"verbose"_a = false,
"stream"_a = nb::none(),
nb::sig(
"def __call__(self, *, inputs: Mapping[str, Union[scalar, array]], output_shapes: Mapping[str, Sequence[int]], output_dtypes: Mapping[str, Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[Mapping[str, Union[bool, int, Dtype]]] = None, init_value: Optional[float] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)"),
R"pbdoc(
Run the kernel.
Args:
inputs (Mapping[str, array]): Inputs. These will be added to the function signature and passed to the Metal kernel.
The keys will be the names of the arguments to the kernel.
output_shapes (Mapping[str, Sequence[int]]): Output shapes. A dict mapping
output variable names to shapes. These will be added to the function signature.
output_dtypes (Mapping[str, Dtype]): Output dtypes. A dict mapping output variable
names to dtypes. Must have the same keys as ``output_shapes``.
grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
template (Mapping[str, Union[bool, int, Dtype]], optional): Template arguments.
These will be added as template arguments to the kernel definition. Default: ``None``.
init_value (float, optional): Optional value to use to initialize all of the output arrays.
By default, output arrays are uninitialized. Default: ``None``.
verbose (bool, optional): Whether to print the full generated source code of the kernel
when it is run. Default: ``False``.
stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.
Returns:
dict[str, array]: Dictionary of output arrays based on ``output_shapes``/``output_dtypes``.
)pbdoc");
} }

View File

@ -562,20 +562,22 @@ class TestFast(mlx_tests.MLXTestCase):
a = mx.random.normal(shape=(2, 2)) a = mx.random.normal(shape=(2, 2))
kernel = mx.fast.metal_kernel( kernel = mx.fast.metal_kernel(
name="basic", name="basic",
input_names=["a"],
output_names=["out1"],
source=""" source="""
uint elem = thread_position_in_grid.x; uint elem = thread_position_in_grid.x;
out1[elem] = a[elem]; out1[elem] = a[elem];
""", """,
) )
out = kernel( out = kernel(
inputs={"a": a}, inputs=[a],
grid=(4, 1, 1), grid=(4, 1, 1),
threadgroup=(2, 1, 1), threadgroup=(2, 1, 1),
output_shapes={"out1": (2, 2)}, output_shapes=[(2, 2)],
output_dtypes={"out1": mx.float32}, output_dtypes=[mx.float32],
stream=mx.gpu, stream=mx.gpu,
) )
self.assertTrue(mx.allclose(out["out1"], a)) self.assertTrue(mx.allclose(out[0], a))
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available") @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
def test_custom_kernel_args(self): def test_custom_kernel_args(self):
@ -585,6 +587,8 @@ class TestFast(mlx_tests.MLXTestCase):
kernel = mx.fast.metal_kernel( kernel = mx.fast.metal_kernel(
name="arg_test", name="arg_test",
input_names=["a", "b", "c", "d"],
output_names=["out1", "out2"],
source=""" source="""
uint elem = thread_position_in_grid.x; uint elem = thread_position_in_grid.x;
T tmp = a[0]; T tmp = a[0];
@ -597,26 +601,26 @@ class TestFast(mlx_tests.MLXTestCase):
""", """,
) )
out = kernel( out = kernel(
inputs={ inputs=[
"a": a, a,
"b": mx.array([3, 4, 5]), mx.array([3, 4, 5]),
"c": c, c,
"d": 7.3, 7.3,
}, ],
template={ template=[
"e": True, ("e", True),
"f": 3, ("f", 3),
"T": mx.float16, ("T", mx.float16),
}, ],
grid=(6, 1, 1), grid=(6, 1, 1),
threadgroup=(2, 1, 1), threadgroup=(2, 1, 1),
output_shapes={"out1": (2, 2), "out2": (3, 2)}, output_shapes=[(2, 2), (3, 2)],
output_dtypes={"out1": mx.float32, "out2": mx.int32}, output_dtypes=[mx.float32, mx.int32],
stream=mx.gpu, stream=mx.gpu,
) )
self.assertTrue(mx.allclose(out["out1"], mx.full((2, 2), 14.0484))) self.assertTrue(mx.allclose(out[0], mx.full((2, 2), 14.0484)))
self.assertTrue(mx.allclose(out["out2"], mx.full((3, 2), -2, dtype=mx.int32))) self.assertTrue(mx.allclose(out[1], mx.full((3, 2), -2, dtype=mx.int32)))
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available") @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
def test_custom_kernel_strides(self): def test_custom_kernel_strides(self):
@ -640,19 +644,21 @@ class TestFast(mlx_tests.MLXTestCase):
for contig in [True, False]: for contig in [True, False]:
kernel = mx.fast.metal_kernel( kernel = mx.fast.metal_kernel(
name="myexp" + str(contig), name="myexp" + str(contig),
input_names=["inp"],
output_names=["out"],
source=source_contig if contig else source, source=source_contig if contig else source,
ensure_row_contiguous=contig, ensure_row_contiguous=contig,
) )
outputs = kernel( outputs = kernel(
inputs={"inp": a}, inputs=[a],
template={"T": mx.float32}, template=[("T", mx.float32)],
grid=(a.size, 1, 1), grid=(a.size, 1, 1),
threadgroup=(256, 1, 1), threadgroup=(256, 1, 1),
output_shapes={"out": a.shape}, output_shapes=[a.shape],
output_dtypes={"out": a.dtype}, output_dtypes=[a.dtype],
stream=mx.gpu, stream=mx.gpu,
) )
self.assertTrue(mx.allclose(mx.exp(a) * 32, outputs["out"])) self.assertTrue(mx.allclose(mx.exp(a) * 32, outputs[0]))
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available") @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
def test_custom_kernel_helper(self): def test_custom_kernel_helper(self):
@ -660,6 +666,8 @@ class TestFast(mlx_tests.MLXTestCase):
a = mx.random.normal(shape=(2, 2)) a = mx.random.normal(shape=(2, 2))
kernel = mx.fast.metal_kernel( kernel = mx.fast.metal_kernel(
name="helper", name="helper",
input_names=["a"],
output_names=["out1"],
header=""" header="""
template <typename T> template <typename T>
T do_exp(T x) { T do_exp(T x) {
@ -672,14 +680,14 @@ class TestFast(mlx_tests.MLXTestCase):
""", """,
) )
out = kernel( out = kernel(
inputs={"a": a}, inputs=[a],
grid=(4, 1, 1), grid=(4, 1, 1),
threadgroup=(2, 1, 1), threadgroup=(2, 1, 1),
output_shapes={"out1": (2, 2)}, output_shapes=[(2, 2)],
output_dtypes={"out1": mx.float32}, output_dtypes=[mx.float32],
stream=mx.gpu, stream=mx.gpu,
) )
self.assertTrue(mx.allclose(out["out1"], mx.exp(a))) self.assertTrue(mx.allclose(out[0], mx.exp(a)))
if __name__ == "__main__": if __name__ == "__main__":