Simplifications for MLX C (#1396)
* simplifications for MLX C
* use vectors instead of map
* update examples
parent 7cca1727af
commit ba3e913c7a
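In short: ``mx.fast.metal_kernel`` now takes ``input_names`` and ``output_names`` up front, the call site passes lists instead of name-keyed dicts, and outputs come back as a list. A condensed before/after sketch, distilled from the example diffs below (the kernel body and sizes are illustrative):

    import mlx.core as mx

    # Illustrative kernel body; MLX generates the full Metal signature.
    source = "out[thread_position_in_grid.x] = inp[thread_position_in_grid.x];"

    kernel = mx.fast.metal_kernel(
        name="copy",
        input_names=["inp"],   # new: input names are declared at construction
        output_names=["out"],  # new: output names are declared at construction
        source=source,
    )
    a = mx.random.normal(shape=(16,))
    outputs = kernel(
        inputs=[a],               # was: inputs={"inp": a}
        grid=(a.size, 1, 1),
        threadgroup=(16, 1, 1),
        output_shapes=[a.shape],  # was: output_shapes={"out": a.shape}
        output_dtypes=[a.dtype],  # was: output_dtypes={"out": a.dtype}
        stream=mx.gpu,
    )
    b = outputs[0]                # was: outputs["out"]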
@@ -19,17 +19,19 @@ Let's write a custom kernel that computes ``exp`` elementwise:

     kernel = mx.fast.metal_kernel(
         name="myexp",
+        input_names=["inp"],
+        output_names=["out"],
         source=source,
     )
     outputs = kernel(
-        inputs={"inp": a},
-        template={"T": mx.float32},
+        inputs=[a],
+        template=[("T", mx.float32)],
         grid=(a.size, 1, 1),
         threadgroup=(256, 1, 1),
-        output_shapes={"out": a.shape},
-        output_dtypes={"out": a.dtype},
+        output_shapes=[a.shape],
+        output_dtypes=[a.dtype],
     )
-    return outputs["out"]
+    return outputs[0]

 a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
 b = exp_elementwise(a)
@@ -40,16 +42,16 @@ Let's write a custom kernel that computes ``exp`` elementwise:

 The full function signature will be generated using:

-* The keys and shapes/dtypes of ``inputs``
+* The shapes/dtypes of ``inputs``

   In the above, ``a`` is an ``mx.array`` of type ``mx.float16`` and we pass it with the key ``inp``
   so we will add ``const device float16_t* inp`` to the signature.
   ``inp_shape``, ``inp_strides`` and ``inp_ndim`` are also added for convenience if they are present
   in ``source``.

-* The keys and values of ``output_shapes`` and ``output_dtypes``
+* The list of ``output_dtypes``

   In the above, ``out`` is an ``mx.array`` of type ``mx.float16``
   so we add ``device float16_t* out``.

 * Template parameters passed using ``template``

-  In the above, ``template={"T": mx.float32}`` adds a template of ``template <typename T>`` to the function
+  In the above, ``template=[("T", mx.float32)]`` adds a template of ``template <typename T>`` to the function
   and instantiates the template with ``custom_kernel_myexp_float<float>``.
   Template parameters can be ``mx.core.Dtype``, ``int`` or ``bool``.

 * Metal attributes used in ``source`` such as ``[[thread_position_in_grid]]``
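To make the template rules above concrete: a minimal sketch (the kernel name, body, and parameter names are illustrative, not from this commit) that passes all three supported template parameter kinds in the new list form:

    import mlx.core as mx

    a = mx.random.normal(shape=(16,))
    kernel = mx.fast.metal_kernel(
        name="tmpl_demo",
        input_names=["inp"],
        output_names=["out"],
        source="""
            uint elem = thread_position_in_grid.x;
            out[elem] = scale ? static_cast<T>(width) * inp[elem] : inp[elem];
        """,
    )
    outputs = kernel(
        inputs=[a],
        # One (name, value) pair per template parameter, in order; this
        # produces ``template <typename T, int width, bool scale>``.
        template=[("T", mx.float32), ("width", 4), ("scale", True)],
        grid=(a.size, 1, 1),
        threadgroup=(16, 1, 1),
        output_shapes=[a.shape],
        output_dtypes=[a.dtype],
        stream=mx.gpu,
    )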
@@ -104,18 +106,20 @@ Let's convert ``myexp`` above to support arbitrarily strided arrays without rely

     kernel = mx.fast.metal_kernel(
         name="myexp_strided",
+        input_names=["inp"],
+        output_names=["out"],
         source=source
     )
     outputs = kernel(
-        inputs={"inp": a},
-        template={"T": mx.float32},
+        inputs=[a],
+        template=[("T", mx.float32)],
         grid=(a.size, 1, 1),
         threadgroup=(256, 1, 1),
-        output_shapes={"out": a.shape},
-        output_dtypes={"out": a.dtype},
+        output_shapes=[a.shape],
+        output_dtypes=[a.dtype],
         ensure_row_contiguous=False,
     )
-    return outputs["out"]
+    return outputs[0]

 a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
 # make non-contiguous
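The hunk above cuts off at the ``# make non-contiguous`` comment. One plausible way to finish the setup (an assumption; the actual line is not shown in this extract) is a strided slice, which exercises the ``ensure_row_contiguous=False`` path:

    a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
    a = a[::2]  # assumed: step slicing yields a strided, non-row-contiguous view
    b = exp_elementwise(a)
    assert mx.allclose(b, mx.exp(a))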
@@ -243,17 +247,19 @@ First we'll implement the forward pass as a fused kernel:

         """
         kernel = mx.fast.metal_kernel(
             name="grid_sample",
+            input_names=["x", "grid"],
+            output_names=["out"],
             source=source,
         )
         outputs = kernel(
-            inputs={"x": x, "grid": grid},
-            template={"T": x.dtype},
-            output_shapes={"out": out_shape},
-            output_dtypes={"out": x.dtype},
+            inputs=[x, grid],
+            template=[("T", x.dtype)],
+            output_shapes=[out_shape],
+            output_dtypes=[x.dtype],
             grid=(np.prod(out_shape), 1, 1),
             threadgroup=(256, 1, 1),
         )
-        return outputs["out"]
+        return outputs[0]

 For a reasonably sized input such as:
@@ -389,6 +395,8 @@ We can then implement the backwards pass as follows:

         """
         kernel = mx.fast.metal_kernel(
             name="grid_sample_grad",
+            input_names=["x", "grid", "cotangent"],
+            output_names=["x_grad", "grid_grad"],
             source=source,
             atomic_outputs=True,
         )

@@ -398,15 +406,15 @@ We can then implement the backwards pass as follows:

         C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
         grid_size = B * gN * gM * C_padded
         outputs = kernel(
-            inputs={"x": x, "grid": grid, "cotangent": cotangent},
-            template={"T": x.dtype},
-            output_shapes={"x_grad": x.shape, "grid_grad": grid.shape},
-            output_dtypes={"x_grad": x.dtype, "grid_grad": x.dtype},
+            inputs=[x, grid, cotangent],
+            template=[("T", x.dtype)],
+            output_shapes=[x.shape, grid.shape],
+            output_dtypes=[x.dtype, x.dtype],
             grid=(grid_size, 1, 1),
             threadgroup=(256, 1, 1),
             init_value=0,
         )
-        return outputs["x_grad"], outputs["grid_grad"]
+        return outputs[0], outputs[1]

 There's an even larger speed up for the vjp:
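The ``C_padded`` expression above rounds the channel count up to a whole number of simdgroups before sizing the launch grid. A worked example of that arithmetic (sizes are hypothetical):

    simdgroup_size = 32          # threads per simdgroup on Apple GPUs
    B, gN, gM, C = 2, 8, 8, 100  # hypothetical problem sizes
    C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
    assert C_padded == 128       # 100 rounded up to the next multiple of 32
    grid_size = B * gN * gM * C_padded
    assert grid_size == 16384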
mlx/fast.cpp (241 changed lines)
@@ -515,7 +515,7 @@ array scaled_dot_product_attention(
     const array& values,
     const float scale,
     const std::optional<array>& mask,
-    const std::optional<int>& memory_efficient_threshold,
+    const std::optional<int> memory_efficient_threshold,
     StreamOrDevice s) {
   for (const auto& tensor : {queries, keys, values}) {
     if (tensor.ndim() != 4) {
@@ -916,47 +916,23 @@ array affine_dequantize(
   return fallback({w, scales, biases})[0];
 }

-void validate_output_shapes(
-    std::map<std::string, std::vector<int>> output_shapes,
-    std::map<std::string, Dtype> output_dtypes) {
-  // Make sure output shapes and dtypes have the same keys
-  bool validated = true;
-  if (output_shapes.size() == 0) {
-    throw std::invalid_argument(
-        "[metal_kernel] Must specify at least one output.");
-  }
-  if (output_shapes.size() != output_dtypes.size()) {
-    validated = false;
-  } else {
-    for (const auto& kv : output_shapes) {
-      if (output_dtypes.find(kv.first) == output_dtypes.end()) {
-        validated = false;
-        break;
-      }
-    }
-  }
-  if (!validated) {
-    throw std::invalid_argument(
-        "[metal_kernel] `output_shapes` and `output_dtypes` must have the same keys.");
-  }
-}
-
 void write_signature(
     std::string func_name,
-    std::string& source,
-    std::map<std::string, array>& inputs,
-    std::map<std::string, std::vector<int>>& output_shapes,
-    std::map<std::string, Dtype>& output_dtypes,
-    std::optional<std::map<std::string, TemplateArg>> template_args,
+    const std::string& source,
+    const std::vector<std::string>& input_names,
+    const std::vector<array>& inputs,
+    const std::vector<std::string>& output_names,
+    const std::vector<Dtype>& output_dtypes,
+    const std::vector<std::pair<std::string, TemplateArg>>& template_args,
     std::vector<CustomKernelShapeInfo>& shape_infos,
     bool atomic_outputs,
     std::ostringstream& kernel_source) {
   // Auto-generate a function signature based on `template_args`
   // and the dtype/shape of the arrays passed as `inputs`.
-  if (template_args && template_args.value().size() > 0) {
+  if (!template_args.empty()) {
     kernel_source << "template <";
     int i = 0;
-    for (const auto& [name, arg] : template_args.value()) {
+    for (const auto& [name, arg] : template_args) {
       std::string param_type;
       if (std::holds_alternative<int>(arg)) {
         param_type = "int";
@@ -1008,7 +984,9 @@ void write_signature(
   int index = 0;
   constexpr int max_constant_array_size = 8;
   // Add inputs
-  for (const auto& [name, arr] : inputs) {
+  for (int i = 0; i < inputs.size(); ++i) {
+    const auto& name = input_names[i];
+    const auto& arr = inputs[i];
     auto dtype = get_type_string(arr.dtype());
     bool is_constant =
         arr.is_available() && arr.size() < max_constant_array_size;
@@ -1042,7 +1020,9 @@ void write_signature(
     shape_infos.push_back(shape_info);
   }
   // Add outputs
-  for (const auto& [name, dtype] : output_dtypes) {
+  for (int i = 0; i < output_names.size(); ++i) {
+    const auto& name = output_names[i];
+    const auto& dtype = output_dtypes[i];
     kernel_source << " device ";
     auto type_string = get_type_string(dtype);
     if (atomic_outputs) {
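Both rewritten loops pair the i-th positional value with the i-th declared name. A Python-level sketch of the same pairing (names and the emitted text are illustrative; the real code emits concrete dtypes, not ``T``):

    input_names = ["x", "grid"]
    output_names = ["out"]
    for i, name in enumerate(input_names):
        print(f"const device T* {name} [[buffer({i})]]")
    for j, name in enumerate(output_names):
        print(f"device T* {name} [[buffer({len(input_names) + j})]]")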
@@ -1051,7 +1031,7 @@ void write_signature(
       kernel_source << type_string;
     }
     kernel_source << "* " << name << " [[buffer(" << index << ")]]";
-    if (index < inputs.size() + output_shapes.size() - 1 || attrs.size() > 0) {
+    if (index < inputs.size() + output_names.size() - 1 || attrs.size() > 0) {
       kernel_source << "," << std::endl;
     } else {
       kernel_source << ") {" << std::endl;
@@ -1073,7 +1053,8 @@ void write_signature(
   kernel_source << "}" << std::endl;
 }

-std::string write_template(std::map<std::string, TemplateArg>& template_args) {
+std::string write_template(
+    const std::vector<std::pair<std::string, TemplateArg>>& template_args) {
   std::ostringstream template_def;
   template_def << "<";
   int i = 0;
@@ -1094,107 +1075,115 @@ std::string write_template(std::map<std::string, TemplateArg>& template_args) {
   return template_def.str();
 }

-std::map<std::string, array> MetalKernel::operator()(
-    std::map<std::string, array>& inputs,
-    std::map<std::string, std::vector<int>> output_shapes,
-    std::map<std::string, Dtype> output_dtypes,
-    std::tuple<int, int, int> grid,
-    std::tuple<int, int, int> threadgroup,
-    std::optional<std::map<std::string, TemplateArg>> template_args,
-    std::optional<float> init_value,
-    bool verbose,
-    StreamOrDevice s_) {
-  validate_output_shapes(output_shapes, output_dtypes);
-
-  auto s = to_stream(s_);
-  if (s.device != Device::gpu) {
+MetalKernelFunction metal_kernel(
+    const std::string& name,
+    const std::vector<std::string>& input_names,
+    const std::vector<std::string>& output_names,
+    const std::string& source,
+    const std::string& header /* = "" */,
+    bool ensure_row_contiguous /* = true */,
+    bool atomic_outputs /* = false */) {
+  if (output_names.empty()) {
     throw std::invalid_argument(
-        "[metal_kernel] MetalKernel only works on GPU.");
+        "[metal_kernel] Must specify at least one output.");
   }

-  std::ostringstream func_name;
+  return [=](const std::vector<array>& inputs,
+             const std::vector<std::vector<int>>& output_shapes,
+             const std::vector<Dtype>& output_dtypes,
+             std::tuple<int, int, int> grid,
+             std::tuple<int, int, int> threadgroup,
+             const std::vector<std::pair<std::string, TemplateArg>>&
+                 template_args = {},
+             std::optional<float> init_value = std::nullopt,
+             bool verbose = false,
+             StreamOrDevice s_ = {}) {
+    if (inputs.size() != input_names.size()) {
+      std::ostringstream msg;
+      msg << "[metal_kernel] Expected `inputs` to have size "
+          << input_names.size() << " but got size " << inputs.size() << "."
+          << std::endl;
+      throw std::invalid_argument(msg.str());
+    }
+    if (output_shapes.size() != output_names.size()) {
+      std::ostringstream msg;
+      msg << "[metal_kernel] Expected `output_shapes` to have size "
+          << output_names.size() << " but got size " << output_shapes.size()
+          << "." << std::endl;
+      throw std::invalid_argument(msg.str());
+    }
+    if (output_dtypes.size() != output_names.size()) {
+      std::ostringstream msg;
+      msg << "[metal_kernel] Expected `output_dtypes` to have size "
+          << output_names.size() << " but got size " << output_dtypes.size()
+          << "." << std::endl;
+      throw std::invalid_argument(msg.str());
+    }

-  std::string template_def = "";
-  bool needs_template = template_args && template_args.value().size() > 0;
-  std::string hash_key = "";
-  if (needs_template) {
-    std::regex disallowed_chars("\\<|\\>|(, )");
-    template_def = write_template(template_args.value());
-    hash_key = std::regex_replace(template_def, disallowed_chars, "_");
-    hash_key.pop_back();
-  }
+    auto s = to_stream(s_);
+    if (s.device != Device::gpu) {
+      throw std::invalid_argument("[metal_kernel] Only supports the GPU.");
+    }

-  func_name << "custom_kernel_" << name_ << hash_key;
-  std::string kernel_name = func_name.str();
+    std::ostringstream func_name;

-  std::ostringstream kernel_source;
-  kernel_source << header_ << std::endl;
+    std::string template_def = "";
+    std::string hash_key = "";
+    if (!template_args.empty()) {
+      std::regex disallowed_chars("\\<|\\>|(, )");
+      template_def = write_template(template_args);
+      hash_key = std::regex_replace(template_def, disallowed_chars, "_");
+      hash_key.pop_back();
+    }

-  std::vector<CustomKernelShapeInfo> shape_infos;
-  write_signature(
-      func_name.str(),
-      source_,
-      inputs,
-      output_shapes,
-      output_dtypes,
-      template_args,
-      shape_infos,
-      atomic_outputs_,
-      kernel_source);
+    func_name << "custom_kernel_" << name << hash_key;
+    std::string kernel_name = func_name.str();

-  if (needs_template) {
-    template_def = func_name.str() + template_def;
-    kernel_source << std::endl
-                  << "template [[host_name(\"" << kernel_name
-                  << "\")]] [[kernel]] decltype(" << template_def << ") "
-                  << template_def << ";" << std::endl;
-  }
+    std::ostringstream kernel_source;
+    kernel_source << header << std::endl;

-  if (verbose) {
-    std::cout << "Generated source code for `" << name_ << "`:" << std::endl
-              << "```" << std::endl
-              << kernel_source.str() << std::endl
-              << "```" << std::endl;
-  }
+    std::vector<CustomKernelShapeInfo> shape_infos;
+    write_signature(
+        func_name.str(),
+        source,
+        input_names,
+        inputs,
+        output_names,
+        output_dtypes,
+        template_args,
+        shape_infos,
+        atomic_outputs,
+        kernel_source);

-  std::vector<array> in_arrs;
-  for (const auto& kv : inputs) {
-    in_arrs.push_back(kv.second);
-  }
+    if (!template_args.empty()) {
+      template_def = func_name.str() + template_def;
+      kernel_source << std::endl
+                    << "template [[host_name(\"" << kernel_name
+                    << "\")]] [[kernel]] decltype(" << template_def << ") "
+                    << template_def << ";" << std::endl;
+    }

-  std::vector<std::string> out_keys;
-  std::vector<std::vector<int>> out_shapes;
-  for (const auto& [name, shape] : output_shapes) {
-    out_keys.push_back(name);
-    out_shapes.push_back(shape);
-  }
+    if (verbose) {
+      std::cout << "Generated source code for `" << name << "`:" << std::endl
+                << "```" << std::endl
+                << kernel_source.str() << std::endl
+                << "```" << std::endl;
+    }

-  std::vector<Dtype> out_dtypes;
-  for (const auto& kv : output_dtypes) {
-    out_dtypes.push_back(kv.second);
-  }
-
-  std::map<std::string, array> outputs;
-  auto outputs_vec = array::make_arrays(
-      out_shapes,
-      out_dtypes,
-      std::make_shared<CustomKernel>(
-          s,
-          kernel_name,
-          kernel_source.str(),
-          grid,
-          threadgroup,
-          shape_infos,
-          ensure_row_contiguous_,
-          init_value),
-      in_arrs);
-
-  int i = 0;
-  for (const auto& key : out_keys) {
-    outputs.insert({key, outputs_vec[i]});
-    i++;
-  }
-  return outputs;
+    return array::make_arrays(
+        output_shapes,
+        output_dtypes,
+        std::make_shared<CustomKernel>(
+            s,
+            kernel_name,
+            kernel_source.str(),
+            grid,
+            threadgroup,
+            shape_infos,
+            ensure_row_contiguous,
+            init_value),
+        inputs);
+  };
 }

 } // namespace mlx::core::fast
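The refactor above replaces the ``MetalKernel`` class with a factory, ``fast::metal_kernel``, that validates what it can at construction time and returns a closure holding the configuration. A rough Python analogy of that shape (illustrative only, not the actual binding):

    def metal_kernel(name, input_names, output_names, source):
        # Construction-time check, as in the C++ factory above.
        if not output_names:
            raise ValueError("[metal_kernel] Must specify at least one output.")

        def kernel(inputs, output_shapes, output_dtypes, grid, threadgroup):
            # Call-time checks: positional arguments must match declared names.
            if len(inputs) != len(input_names):
                raise ValueError(
                    f"[metal_kernel] Expected `inputs` to have size "
                    f"{len(input_names)} but got size {len(inputs)}."
                )
            # ... generate the signature, compile, launch, return a list ...
            return []

        return kernel  # closure captures name, input_names, output_names, source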
mlx/fast.h (53 changed lines)
@@ -2,7 +2,6 @@

 #pragma once

-#include <map>
 #include <optional>

 #include "mlx/utils.h"
@@ -39,7 +38,7 @@ array scaled_dot_product_attention(
     const array& values,
     const float scale,
     const std::optional<array>& mask = std::nullopt,
-    const std::optional<int>& memory_efficient_threshold = std::nullopt,
+    const std::optional<int> memory_efficient_threshold = std::nullopt,
     StreamOrDevice s = {});

 std::tuple<array, array, array> affine_quantize(
@@ -66,37 +65,25 @@ array affine_dequantize(

 typedef std::variant<int, bool, Dtype> TemplateArg;

-class MetalKernel {
- public:
-  MetalKernel(
-      const std::string& name,
-      const std::string& source,
-      const std::string& header = "",
-      bool ensure_row_contiguous = true,
-      bool atomic_outputs = false)
-      : name_(name),
-        source_(source),
-        header_(header),
-        ensure_row_contiguous_(ensure_row_contiguous),
-        atomic_outputs_(atomic_outputs) {}
+typedef std::function<std::vector<array>(
+    const std::vector<array>&,
+    const std::vector<std::vector<int>>&,
+    const std::vector<Dtype>&,
+    std::tuple<int, int, int>,
+    std::tuple<int, int, int>,
+    std::vector<std::pair<std::string, TemplateArg>>,
+    std::optional<float>,
+    bool,
+    StreamOrDevice)>
+    MetalKernelFunction;

-  std::map<std::string, array> operator()(
-      std::map<std::string, array>& inputs,
-      std::map<std::string, std::vector<int>> output_shapes,
-      std::map<std::string, Dtype> output_dtypes,
-      std::tuple<int, int, int> grid,
-      std::tuple<int, int, int> threadgroup,
-      std::optional<std::map<std::string, TemplateArg>> template_args =
-          std::nullopt,
-      std::optional<float> init_value = std::nullopt,
-      bool verbose = false,
-      StreamOrDevice s = {});
+MetalKernelFunction metal_kernel(
+    const std::string& name,
+    const std::vector<std::string>& input_names,
+    const std::vector<std::string>& output_names,
+    const std::string& source,
+    const std::string& header = "",
+    bool ensure_row_contiguous = true,
+    bool atomic_outputs = false);

- private:
-  std::string name_;
-  std::string source_;
-  std::string header_;
-  bool ensure_row_contiguous_;
-  bool atomic_outputs_;
-};
-
 } // namespace mlx::core::fast
@@ -1425,8 +1425,8 @@ array where(
 array nan_to_num(
     const array& a,
     float nan /* = 0.0f */,
-    const std::optional<float>& posinf_ /* = std::nullopt */,
-    const std::optional<float>& neginf_ /* = std::nullopt */,
+    const std::optional<float> posinf_ /* = std::nullopt */,
+    const std::optional<float> neginf_ /* = std::nullopt */,
     StreamOrDevice s /* = {} */) {
   Dtype dtype = a.dtype();
   if (!issubdtype(dtype, inexact)) {
@@ -416,8 +416,8 @@ array where(
 array nan_to_num(
     const array& a,
     float nan = 0.0f,
-    const std::optional<float>& posinf = std::nullopt,
-    const std::optional<float>& neginf = std::nullopt,
+    const std::optional<float> posinf = std::nullopt,
+    const std::optional<float> neginf = std::nullopt,
     StreamOrDevice s = {});

 /** True if all elements in the array are true (or non-zero). **/
@@ -1,8 +1,8 @@
 // Copyright © 2023-2024 Apple Inc.

 #include <nanobind/nanobind.h>
-#include <nanobind/stl/map.h>
 #include <nanobind/stl/optional.h>
+#include <nanobind/stl/pair.h>
 #include <nanobind/stl/string.h>
 #include <nanobind/stl/tuple.h>
 #include <nanobind/stl/variant.h>
@@ -193,39 +193,130 @@ void init_fast(nb::module_& parent_module) {
         array: The quantized version of ``w``
       )pbdoc");

-  nb::class_<fast::MetalKernel>(
-      m,
+  m.def(
       "metal_kernel",
+      [](const std::string& name,
+         const std::vector<std::string>& input_names,
+         const std::vector<std::string>& output_names,
+         const std::string& source,
+         const std::string& header,
+         bool ensure_row_contiguous,
+         bool atomic_outputs) {
+        auto kernel = fast::metal_kernel(
+            name,
+            input_names,
+            output_names,
+            source,
+            header,
+            ensure_row_contiguous,
+            atomic_outputs);
+        return nb::cpp_function(
+            [kernel = std::move(kernel)](
+                const std::vector<ScalarOrArray>& inputs_,
+                const std::vector<std::vector<int>>& output_shapes,
+                const std::vector<Dtype>& output_dtypes,
+                std::tuple<int, int, int> grid,
+                std::tuple<int, int, int> threadgroup,
+                const std::optional<
+                    std::vector<std::pair<std::string, nb::object>>>&
+                    template_args_ = std::nullopt,
+                std::optional<float> init_value = std::nullopt,
+                bool verbose = false,
+                StreamOrDevice s = {}) {
+              std::vector<array> inputs;
+              for (const auto& value : inputs_) {
+                inputs.push_back(to_array(value, std::nullopt));
+              }
+              std::vector<std::pair<std::string, fast::TemplateArg>>
+                  template_args;
+              if (template_args_) {
+                for (const auto& [name, value] : template_args_.value()) {
+                  // Handle bool, int and dtype template args
+                  if (nb::isinstance<bool>(value)) {
+                    bool bool_val = nb::cast<bool>(value);
+                    template_args.emplace_back(name, bool_val);
+                  } else if (nb::isinstance<int>(value)) {
+                    int int_val = nb::cast<int>(value);
+                    template_args.emplace_back(name, int_val);
+                  } else if (nb::isinstance<Dtype>(value)) {
+                    Dtype dtype = nb::cast<Dtype>(value);
+                    template_args.emplace_back(name, dtype);
+                  } else {
+                    throw std::invalid_argument(
+                        "[[metal_kernel]] Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`.");
+                  }
+                }
+              }
+              return kernel(
+                  inputs,
+                  output_shapes,
+                  output_dtypes,
+                  grid,
+                  threadgroup,
+                  template_args,
+                  init_value,
+                  verbose,
+                  s);
+            },
+            nb::kw_only(),
+            "inputs"_a,
+            "output_shapes"_a,
+            "output_dtypes"_a,
+            "grid"_a,
+            "threadgroup"_a,
+            "template"_a = nb::none(),
+            "init_value"_a = nb::none(),
+            "verbose"_a = false,
+            "stream"_a = nb::none(),
+            nb::sig(
+                "def __call__(self, *, inputs: List[Union[scalar, array]], output_shapes: List[Sequence[int]], output_dtypes: List[Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[List[Tuple[str, Union[bool, int, Dtype]]]] = None, init_value: Optional[float] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)"),
+            R"pbdoc(
+              Run the kernel.
+
+              Args:
+                inputs (List[array]): The inputs passed to the Metal kernel.
+                output_shapes (List[Sequence[int]]): The list of shapes for each output in ``output_names``.
+                output_dtypes (List[Dtype]): The list of data types for each output in ``output_names``.
+                grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
+                threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
+                template (List[Tuple[str, Union[bool, int, Dtype]]], optional): Template arguments.
+                    These will be added as template arguments to the kernel definition. Default: ``None``.
+                init_value (float, optional): Optional value to use to initialize all of the output arrays.
+                    By default, output arrays are uninitialized. Default: ``None``.
+                verbose (bool, optional): Whether to print the full generated source code of the kernel
+                    when it is run. Default: ``False``.
+                stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.
+
+              Returns:
+                List[array]: The list of output arrays.
+            )pbdoc");
+      },
+      "name"_a,
+      "input_names"_a,
+      "output_names"_a,
+      "source"_a,
+      "header"_a = "",
+      "ensure_row_contiguous"_a = true,
+      "atomic_outputs"_a = false,
       R"pbdoc(
         A jit-compiled custom Metal kernel defined from a source string.
-      )pbdoc")
-      .def(
-          nb::init<
-              const std::string&,
-              const std::string&,
-              const std::string&,
-              bool,
-              bool>(),
-          "name"_a,
-          "source"_a,
-          "header"_a = "",
-          "ensure_row_contiguous"_a = true,
-          "atomic_outputs"_a = false,
-          R"pbdoc(
-            Initialize a metal_kernel.

         Args:
-          name (str): Name for the kernel.
-          source (str): Source code. This is the body of a function in Metal,
-            the function signature will be generated for you. The names of the inputs/outputs
-            are determined by the ``inputs`` and ``output_shapes``/``output_dtypes``
-            used when the kernel is called.
-          header (str): Header source code to include before the main function.
-            Useful for helper functions or includes that should live outside of the main function body.
-          ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
-            before the kernel runs. Default: ``True``.
-          atomic_outputs (bool): Whether to use atomic outputs in the function signature
-            e.g. ``device atomic<float>``. Default: ``False``.
+          name (str): Name for the kernel.
+          input_names (List[str]): The parameter names of the inputs in the
+            function signature.
+          output_names (List[str]): The parameter names of the outputs in the
+            function signature.
+          source (str): Source code. This is the body of a function in Metal,
+            the function signature will be automatically generated.
+          header (str): Header source code to include before the main function.
+            Useful for helper functions or includes that should live outside of
+            the main function body.
+          ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
+            before the kernel runs. Default: ``True``.
+          atomic_outputs (bool): Whether to use atomic outputs in the function signature
+            e.g. ``device atomic<float>``. Default: ``False``.
+
+        Returns:
+          Callable ``metal_kernel``.
@@ -242,103 +333,23 @@ void init_fast(nb::module_& parent_module) {

           kernel = mx.fast.metal_kernel(
               name="myexp",
+              input_names=["inp"],
+              output_names=["out"],
               source=source
           )
           outputs = kernel(
-              inputs={"inp": a},
-              template={"T": mx.float32},
+              inputs=[a],
+              template=[("T", mx.float32)],
               grid=(a.size, 1, 1),
               threadgroup=(256, 1, 1),
-              output_shapes={"out": a.shape},
-              output_dtypes={"out": a.dtype},
+              output_shapes=[a.shape],
+              output_dtypes=[a.dtype],
               verbose=True,
           )
-          return outputs["out"]
+          return outputs[0]

       a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
       b = exp_elementwise(a)
       assert mx.allclose(b, mx.exp(a))
-      )pbdoc")
-      .def(
-          "__call__",
-          [](fast::MetalKernel& kernel,
-             std::map<std::string, ScalarOrArray>& inputs_,
-             std::map<std::string, std::vector<int>>& output_shapes,
-             std::map<std::string, Dtype>& output_dtypes,
-             std::tuple<int, int, int> grid,
-             std::tuple<int, int, int> threadgroup,
-             std::optional<std::map<std::string, nb::handle>> template_args_,
-             std::optional<float> init_value,
-             bool verbose,
-             StreamOrDevice s) {
-            std::map<std::string, array> inputs;
-            for (const auto& [name, value] : inputs_) {
-              auto arr = to_array(value, std::nullopt);
-              inputs.insert({name, arr});
-            }
-            std::map<std::string, fast::TemplateArg> template_args;
-            if (template_args_) {
-              for (const auto& [name, value] : template_args_.value()) {
-                // Handle bool, int and dtype template args
-                if (nb::isinstance<bool>(value)) {
-                  bool bool_val = nb::cast<bool>(value);
-                  template_args.insert({name, bool_val});
-                } else if (nb::isinstance<int>(value)) {
-                  int int_val = nb::cast<int>(value);
-                  template_args.insert({name, int_val});
-                } else if (nb::isinstance<Dtype>(value)) {
-                  Dtype dtype = nb::cast<Dtype>(value);
-                  template_args.insert({name, dtype});
-                } else {
-                  throw std::invalid_argument(
-                      "[[metal_kernel]] Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`.");
-                }
-              }
-            }
-            return kernel(
-                inputs,
-                output_shapes,
-                output_dtypes,
-                grid,
-                threadgroup,
-                template_args,
-                init_value,
-                verbose,
-                s);
-          },
-          nb::kw_only(),
-          "inputs"_a,
-          "output_shapes"_a,
-          "output_dtypes"_a,
-          "grid"_a,
-          "threadgroup"_a,
-          "template"_a = nb::none(),
-          "init_value"_a = nb::none(),
-          "verbose"_a = false,
-          "stream"_a = nb::none(),
-          nb::sig(
-              "def __call__(self, *, inputs: Mapping[str, Union[scalar, array]], output_shapes: Mapping[str, Sequence[int]], output_dtypes: Mapping[str, Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[Mapping[str, Union[bool, int, Dtype]]] = None, init_value: Optional[float] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)"),
-          R"pbdoc(
-            Run the kernel.
-
-            Args:
-              inputs (Mapping[str, array]): Inputs. These will be added to the function signature and passed to the Metal kernel.
-                The keys will be the names of the arguments to the kernel.
-              output_shapes (Mapping[str, Sequence[int]]): Output shapes. A dict mapping
-                output variable names to shapes. These will be added to the function signature.
-              output_dtypes (Mapping[str, Dtype]): Output dtypes. A dict mapping output variable
-                names to dtypes. Must have the same keys as ``output_shapes``.
-              grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
-              threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
-              template (Mapping[str, Union[bool, int, Dtype]], optional): Template arguments.
-                These will be added as template arguments to the kernel definition. Default: ``None``.
-              init_value (float, optional): Optional value to use to initialize all of the output arrays.
-                By default, output arrays are uninitialized. Default: ``None``.
-              verbose (bool, optional): Whether to print the full generated source code of the kernel
-                when it is run. Default: ``False``.
-              stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.
-
-            Returns:
-              dict[str, array]: Dictionary of output arrays based on ``output_shapes``/``output_dtypes``.
-          )pbdoc");
+      )pbdoc");
 }
@@ -562,20 +562,22 @@ class TestFast(mlx_tests.MLXTestCase):
         a = mx.random.normal(shape=(2, 2))
         kernel = mx.fast.metal_kernel(
             name="basic",
+            input_names=["a"],
+            output_names=["out1"],
             source="""
                 uint elem = thread_position_in_grid.x;
                 out1[elem] = a[elem];
             """,
         )
         out = kernel(
-            inputs={"a": a},
+            inputs=[a],
             grid=(4, 1, 1),
             threadgroup=(2, 1, 1),
-            output_shapes={"out1": (2, 2)},
-            output_dtypes={"out1": mx.float32},
+            output_shapes=[(2, 2)],
+            output_dtypes=[mx.float32],
             stream=mx.gpu,
         )
-        self.assertTrue(mx.allclose(out["out1"], a))
+        self.assertTrue(mx.allclose(out[0], a))

     @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
     def test_custom_kernel_args(self):
@@ -585,6 +587,8 @@ class TestFast(mlx_tests.MLXTestCase):

         kernel = mx.fast.metal_kernel(
             name="arg_test",
+            input_names=["a", "b", "c", "d"],
+            output_names=["out1", "out2"],
             source="""
                 uint elem = thread_position_in_grid.x;
                 T tmp = a[0];
@@ -597,26 +601,26 @@ class TestFast(mlx_tests.MLXTestCase):
             """,
         )
         out = kernel(
-            inputs={
-                "a": a,
-                "b": mx.array([3, 4, 5]),
-                "c": c,
-                "d": 7.3,
-            },
-            template={
-                "e": True,
-                "f": 3,
-                "T": mx.float16,
-            },
+            inputs=[
+                a,
+                mx.array([3, 4, 5]),
+                c,
+                7.3,
+            ],
+            template=[
+                ("e", True),
+                ("f", 3),
+                ("T", mx.float16),
+            ],
             grid=(6, 1, 1),
             threadgroup=(2, 1, 1),
-            output_shapes={"out1": (2, 2), "out2": (3, 2)},
-            output_dtypes={"out1": mx.float32, "out2": mx.int32},
+            output_shapes=[(2, 2), (3, 2)],
+            output_dtypes=[mx.float32, mx.int32],
             stream=mx.gpu,
         )

-        self.assertTrue(mx.allclose(out["out1"], mx.full((2, 2), 14.0484)))
-        self.assertTrue(mx.allclose(out["out2"], mx.full((3, 2), -2, dtype=mx.int32)))
+        self.assertTrue(mx.allclose(out[0], mx.full((2, 2), 14.0484)))
+        self.assertTrue(mx.allclose(out[1], mx.full((3, 2), -2, dtype=mx.int32)))

     @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
     def test_custom_kernel_strides(self):
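Note that ``inputs`` accepts plain Python scalars (the ``7.3`` above); the binding converts each entry with ``to_array`` before invoking the kernel. A sketch of the equivalent conversion (illustrative values):

    import mlx.core as mx

    raw_inputs = [mx.zeros((2, 2)), mx.array([3, 4, 5]), 7.3]
    inputs = [x if isinstance(x, mx.array) else mx.array(x) for x in raw_inputs]
    assert inputs[2].dtype == mx.float32  # Python floats become float32 arrays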
@@ -640,19 +644,21 @@ class TestFast(mlx_tests.MLXTestCase):
         for contig in [True, False]:
             kernel = mx.fast.metal_kernel(
                 name="myexp" + str(contig),
+                input_names=["inp"],
+                output_names=["out"],
                 source=source_contig if contig else source,
                 ensure_row_contiguous=contig,
             )
             outputs = kernel(
-                inputs={"inp": a},
-                template={"T": mx.float32},
+                inputs=[a],
+                template=[("T", mx.float32)],
                 grid=(a.size, 1, 1),
                 threadgroup=(256, 1, 1),
-                output_shapes={"out": a.shape},
-                output_dtypes={"out": a.dtype},
+                output_shapes=[a.shape],
+                output_dtypes=[a.dtype],
                 stream=mx.gpu,
             )
-            self.assertTrue(mx.allclose(mx.exp(a) * 32, outputs["out"]))
+            self.assertTrue(mx.allclose(mx.exp(a) * 32, outputs[0]))

     @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
     def test_custom_kernel_helper(self):
@@ -660,6 +666,8 @@ class TestFast(mlx_tests.MLXTestCase):
         a = mx.random.normal(shape=(2, 2))
         kernel = mx.fast.metal_kernel(
             name="helper",
+            input_names=["a"],
+            output_names=["out1"],
             header="""
                 template <typename T>
                 T do_exp(T x) {
@@ -672,14 +680,14 @@ class TestFast(mlx_tests.MLXTestCase):
             """,
         )
        out = kernel(
-            inputs={"a": a},
+            inputs=[a],
             grid=(4, 1, 1),
             threadgroup=(2, 1, 1),
-            output_shapes={"out1": (2, 2)},
-            output_dtypes={"out1": mx.float32},
+            output_shapes=[(2, 2)],
+            output_dtypes=[mx.float32],
             stream=mx.gpu,
         )
-        self.assertTrue(mx.allclose(out["out1"], mx.exp(a)))
+        self.assertTrue(mx.allclose(out[0], mx.exp(a)))


 if __name__ == "__main__":