mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Fix unintuitive metal kernel caching (#2242)
* Fix unintuitive metal kernel caching * alternative solution
This commit is contained in:
parent
2e8cf0b450
commit
1ca616844b
@ -8,11 +8,12 @@ MLX supports writing custom Metal kernels through the Python and C++ APIs.
|
||||
Simple Example
|
||||
--------------
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
Let's write a custom kernel that computes ``exp`` elementwise:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def exp_elementwise(a: mx.array):
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
T tmp = inp[elem];
|
||||
@ -25,6 +26,8 @@ Let's write a custom kernel that computes ``exp`` elementwise:
|
||||
output_names=["out"],
|
||||
source=source,
|
||||
)
|
||||
|
||||
def exp_elementwise(a: mx.array):
|
||||
outputs = kernel(
|
||||
inputs=[a],
|
||||
template=[("T", mx.float32)],
|
||||
@ -39,8 +42,13 @@ Let's write a custom kernel that computes ``exp`` elementwise:
|
||||
b = exp_elementwise(a)
|
||||
assert mx.allclose(b, mx.exp(a))
|
||||
|
||||
Every time you make a kernel, a new Metal library is created and possibly
|
||||
JIT compiled. To reduce the overhead from that, build the kernel once with
|
||||
:func:`fast.metal_kernel` and then use it many times.
|
||||
|
||||
.. note::
|
||||
We are only required to pass the body of the Metal kernel in ``source``.
|
||||
Only pass the body of the Metal kernel in ``source``. The function
|
||||
signature is generated automatically.
|
||||
|
||||
The full function signature will be generated using:
|
||||
|
||||
@ -78,29 +86,34 @@ Putting this all together, the generated function signature for ``myexp`` is as
|
||||
|
||||
template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
|
||||
|
||||
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads <https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_ function.
|
||||
This means we will launch ``mx.prod(grid)`` threads, subdivided into ``threadgroup`` size threadgroups.
|
||||
For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.
|
||||
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads
|
||||
<https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_
|
||||
function. This means we will launch ``mx.prod(grid)`` threads, subdivided into
|
||||
``threadgroup`` size threadgroups. For optimal performance, each thread group
|
||||
dimension should be less than or equal to the corresponding grid dimension.
|
||||
|
||||
Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes.
|
||||
Passing ``verbose=True`` to :func:`ast.metal_kernel.__call__` will print the
|
||||
generated code for debugging purposes.
|
||||
|
||||
Using Shape/Strides
|
||||
-------------------
|
||||
|
||||
``mx.fast.metal_kernel`` supports an argument ``ensure_row_contiguous`` which is ``True`` by default.
|
||||
This will copy the ``mx.array`` inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous.
|
||||
Generally this makes writing the kernel easier, since we don't have to worry about gaps or the ordering of the dims
|
||||
when indexing.
|
||||
:func:`fast.metal_kernel` supports an argument ``ensure_row_contiguous`` which
|
||||
is ``True`` by default. This will copy the array inputs if needed
|
||||
before the kernel is launched to ensure that the memory layout is row
|
||||
contiguous. Generally this makes writing the kernel easier, since we don't
|
||||
have to worry about gaps or the ordering of the dims when indexing.
|
||||
|
||||
If we want to avoid this copy, ``metal_kernel`` automatically passes ``a_shape``, ``a_strides`` and ``a_ndim`` for each
|
||||
input array ``a`` if any are present in ``source``.
|
||||
We can then use MLX's built in indexing utils to fetch the right elements for each thread.
|
||||
If we want to avoid this copy, :func:`fast.metal_kernel` automatically passes
|
||||
``a_shape``, ``a_strides`` and ``a_ndim`` for each input array ``a`` if any are
|
||||
present in ``source``. We can then use MLX's built in indexing utils to fetch
|
||||
the right elements for each thread.
|
||||
|
||||
Let's convert ``myexp`` above to support arbitrarily strided arrays without relying on a copy from ``ensure_row_contiguous``:
|
||||
Let's convert ``myexp`` above to support arbitrarily strided arrays without
|
||||
relying on a copy from ``ensure_row_contiguous``:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def exp_elementwise(a: mx.array):
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
// Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
|
||||
@ -116,6 +129,8 @@ Let's convert ``myexp`` above to support arbitrarily strided arrays without rely
|
||||
output_names=["out"],
|
||||
source=source
|
||||
)
|
||||
|
||||
def exp_elementwise(a: mx.array):
|
||||
outputs = kernel(
|
||||
inputs=[a],
|
||||
template=[("T", mx.float32)],
|
||||
@ -183,25 +198,13 @@ We'll start with the following MLX implementation using standard ops:
|
||||
|
||||
return output
|
||||
|
||||
Now let's use ``mx.custom_function`` together with ``mx.fast.metal_kernel``
|
||||
Now let's use :func:`custom_function` together with :func:`fast.metal_kernel`
|
||||
to write a fast GPU kernel for both the forward and backward passes.
|
||||
|
||||
First we'll implement the forward pass as a fused kernel:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@mx.custom_function
|
||||
def grid_sample(x, grid):
|
||||
|
||||
assert x.ndim == 4, "`x` must be 4D."
|
||||
assert grid.ndim == 4, "`grid` must be 4D."
|
||||
|
||||
B, _, _, C = x.shape
|
||||
_, gN, gM, D = grid.shape
|
||||
out_shape = (B, gN, gM, C)
|
||||
|
||||
assert D == 2, "Last dim of `grid` must be size 2."
|
||||
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
int H = x_shape[1];
|
||||
@ -251,12 +254,26 @@ First we'll implement the forward pass as a fused kernel:
|
||||
|
||||
out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
|
||||
"""
|
||||
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="grid_sample",
|
||||
input_names=["x", "grid"],
|
||||
output_names=["out"],
|
||||
source=source,
|
||||
)
|
||||
|
||||
@mx.custom_function
|
||||
def grid_sample(x, grid):
|
||||
|
||||
assert x.ndim == 4, "`x` must be 4D."
|
||||
assert grid.ndim == 4, "`grid` must be 4D."
|
||||
|
||||
B, _, _, C = x.shape
|
||||
_, gN, gM, D = grid.shape
|
||||
out_shape = (B, gN, gM, C)
|
||||
|
||||
assert D == 2, "Last dim of `grid` must be size 2."
|
||||
|
||||
outputs = kernel(
|
||||
inputs=[x, grid],
|
||||
template=[("T", x.dtype)],
|
||||
@ -281,11 +298,11 @@ On an M1 Max, we see a big performance improvement:
|
||||
Grid Sample VJP
|
||||
---------------
|
||||
|
||||
Since we decorated ``grid_sample`` with ``mx.custom_function``, we can now define
|
||||
its custom vjp transform so MLX can differentiate it.
|
||||
Since we decorated ``grid_sample`` with :func:`custom_function`, we can now
|
||||
define its custom vjp transform so MLX can differentiate it.
|
||||
|
||||
The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so
|
||||
requires a few extra ``mx.fast.metal_kernel`` features:
|
||||
requires a few extra :func:`fast.metal_kernel` features:
|
||||
|
||||
* ``init_value=0``
|
||||
Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.
|
||||
@ -299,14 +316,6 @@ We can then implement the backwards pass as follows:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@grid_sample.vjp
|
||||
def grid_sample_vjp(primals, cotangent, _):
|
||||
x, grid = primals
|
||||
B, _, _, C = x.shape
|
||||
_, gN, gM, D = grid.shape
|
||||
|
||||
assert D == 2, "Last dim of `grid` must be size 2."
|
||||
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
int H = x_shape[1];
|
||||
@ -406,6 +415,15 @@ We can then implement the backwards pass as follows:
|
||||
source=source,
|
||||
atomic_outputs=True,
|
||||
)
|
||||
|
||||
@grid_sample.vjp
|
||||
def grid_sample_vjp(primals, cotangent, _):
|
||||
x, grid = primals
|
||||
B, _, _, C = x.shape
|
||||
_, gN, gM, D = grid.shape
|
||||
|
||||
assert D == 2, "Last dim of `grid` must be size 2."
|
||||
|
||||
# pad the output channels to simd group size
|
||||
# so that our `simd_sum`s don't overlap.
|
||||
simdgroup_size = 32
|
||||
|
@ -397,11 +397,11 @@ below.
|
||||
std::ostringstream kname;
|
||||
kname << "axpby_" << "general_" << type_to_name(out);
|
||||
|
||||
// Make sure the metal library is available
|
||||
d.register_library("mlx_ext");
|
||||
// Load the metal library
|
||||
auto lib = d.get_library("mlx_ext");
|
||||
|
||||
// Make a kernel from this metal library
|
||||
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
|
||||
auto kernel = d.get_kernel(kname.str(), lib);
|
||||
|
||||
// Prepare to encode kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
|
@ -172,11 +172,11 @@ void Axpby::eval_gpu(
|
||||
kname << (contiguous_kernel ? "contiguous_" : "general_");
|
||||
kname << type_to_name(out);
|
||||
|
||||
// Make sure the metal library is available
|
||||
d.register_library("mlx_ext");
|
||||
// Load the metal library
|
||||
auto lib = d.get_library("mlx_ext");
|
||||
|
||||
// Make a kernel from this metal library
|
||||
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
|
||||
auto kernel = d.get_kernel(kname.str(), lib);
|
||||
|
||||
// Prepare to encode kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
|
@ -677,7 +677,7 @@ void depthwise_conv_2D_gpu(
|
||||
std::string hash_name = kname.str();
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
|
||||
auto kernel = d.get_kernel(base_name, hash_name, func_consts);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
|
@ -1,12 +1,326 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <iostream>
|
||||
#include <regex>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/metal/jit/includes.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/fast.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core::fast {
|
||||
|
||||
struct CustomKernelCache {
|
||||
std::unordered_map<std::string, std::string> libraries;
|
||||
};
|
||||
|
||||
static CustomKernelCache& cache() {
|
||||
static CustomKernelCache cache_;
|
||||
return cache_;
|
||||
};
|
||||
|
||||
std::string write_signature(
|
||||
std::string func_name,
|
||||
const std::string& header,
|
||||
const std::string& source,
|
||||
const std::vector<std::string>& input_names,
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<std::string>& output_names,
|
||||
const std::vector<Dtype>& output_dtypes,
|
||||
const std::vector<std::pair<std::string, TemplateArg>>& template_args,
|
||||
const std::vector<std::string>& attributes,
|
||||
const std::vector<CustomKernelShapeInfo>& shape_infos,
|
||||
bool atomic_outputs) {
|
||||
std::string kernel_source;
|
||||
kernel_source.reserve(header.size() + source.size() + 16384);
|
||||
kernel_source += header;
|
||||
// Auto-generate a function signature based on `template_args`
|
||||
// and the dtype/shape of the arrays passed as `inputs`.
|
||||
if (!template_args.empty()) {
|
||||
kernel_source += "template <";
|
||||
int i = 0;
|
||||
for (const auto& [name, arg] : template_args) {
|
||||
std::string param_type;
|
||||
if (std::holds_alternative<int>(arg)) {
|
||||
param_type = "int";
|
||||
} else if (std::holds_alternative<bool>(arg)) {
|
||||
param_type = "bool";
|
||||
} else if (std::holds_alternative<Dtype>(arg)) {
|
||||
param_type = "typename";
|
||||
}
|
||||
if (i > 0) {
|
||||
kernel_source += ", ";
|
||||
}
|
||||
kernel_source += param_type;
|
||||
kernel_source += " ";
|
||||
kernel_source += name;
|
||||
i++;
|
||||
}
|
||||
kernel_source += ">\n";
|
||||
}
|
||||
kernel_source += "[[kernel]] void ";
|
||||
kernel_source += func_name;
|
||||
kernel_source += "(\n";
|
||||
|
||||
int index = 0;
|
||||
constexpr int max_constant_array_size = 8;
|
||||
// Add inputs
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
const auto& name = input_names[i];
|
||||
const auto& arr = inputs[i];
|
||||
auto dtype = get_type_string(arr.dtype());
|
||||
std::string location =
|
||||
arr.size() < max_constant_array_size ? "constant" : "device";
|
||||
std::string ref = arr.ndim() == 0 ? "&" : "*";
|
||||
kernel_source += " const ";
|
||||
kernel_source += location;
|
||||
kernel_source += " ";
|
||||
kernel_source += dtype;
|
||||
kernel_source += ref;
|
||||
kernel_source += " ";
|
||||
kernel_source += name;
|
||||
kernel_source += " [[buffer(";
|
||||
kernel_source += std::to_string(index);
|
||||
kernel_source += ")]],\n";
|
||||
index++;
|
||||
// Add input shape, strides and ndim if present in the source
|
||||
if (arr.ndim() > 0) {
|
||||
if (shape_infos[i].shape) {
|
||||
kernel_source +=
|
||||
(" const constant int* " + name + "_shape [[buffer(" +
|
||||
std::to_string(index) + ")]],\n");
|
||||
index++;
|
||||
}
|
||||
if (shape_infos[i].strides) {
|
||||
kernel_source +=
|
||||
(" const constant int64_t* " + name + "_strides [[buffer(" +
|
||||
std::to_string(index) + ")]],\n");
|
||||
index++;
|
||||
}
|
||||
if (shape_infos[i].ndim) {
|
||||
kernel_source +=
|
||||
(" const constant int& " + name + "_ndim [[buffer(" +
|
||||
std::to_string(index) + ")]],\n");
|
||||
index++;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Add outputs
|
||||
for (int i = 0; i < output_names.size(); ++i) {
|
||||
const auto& name = output_names[i];
|
||||
const auto& dtype = output_dtypes[i];
|
||||
kernel_source += " device ";
|
||||
auto type_string = get_type_string(dtype);
|
||||
if (atomic_outputs) {
|
||||
kernel_source += "atomic<";
|
||||
}
|
||||
kernel_source += type_string;
|
||||
if (atomic_outputs) {
|
||||
kernel_source += ">";
|
||||
}
|
||||
kernel_source += "* ";
|
||||
kernel_source += name;
|
||||
kernel_source += " [[buffer(";
|
||||
kernel_source += std::to_string(index);
|
||||
kernel_source += ")]]";
|
||||
if (index < inputs.size() + output_names.size() - 1 ||
|
||||
attributes.size() > 0) {
|
||||
kernel_source += ",\n";
|
||||
} else {
|
||||
kernel_source += ") {\n";
|
||||
}
|
||||
index++;
|
||||
}
|
||||
|
||||
index = 0;
|
||||
for (const auto& attr : attributes) {
|
||||
kernel_source += attr;
|
||||
if (index < attributes.size() - 1) {
|
||||
kernel_source += ",\n";
|
||||
} else {
|
||||
kernel_source += ") {\n";
|
||||
}
|
||||
index++;
|
||||
}
|
||||
kernel_source += source;
|
||||
kernel_source += "\n}\n";
|
||||
return kernel_source;
|
||||
}
|
||||
|
||||
std::string write_template(
|
||||
const std::vector<std::pair<std::string, TemplateArg>>& template_args) {
|
||||
std::ostringstream template_def;
|
||||
template_def << "<";
|
||||
int i = 0;
|
||||
for (const auto& [name, arg] : template_args) {
|
||||
if (i > 0) {
|
||||
template_def << ", ";
|
||||
}
|
||||
if (std::holds_alternative<int>(arg)) {
|
||||
template_def << std::get<int>(arg);
|
||||
} else if (std::holds_alternative<bool>(arg)) {
|
||||
template_def << std::get<bool>(arg);
|
||||
} else if (std::holds_alternative<Dtype>(arg)) {
|
||||
template_def << get_type_string(std::get<Dtype>(arg));
|
||||
}
|
||||
i++;
|
||||
}
|
||||
template_def << ">";
|
||||
return template_def.str();
|
||||
}
|
||||
|
||||
MetalKernelFunction metal_kernel(
|
||||
const std::string& name,
|
||||
const std::vector<std::string>& input_names,
|
||||
const std::vector<std::string>& output_names,
|
||||
const std::string& source,
|
||||
const std::string& header /* = "" */,
|
||||
bool ensure_row_contiguous /* = true */,
|
||||
bool atomic_outputs /* = false */) {
|
||||
if (output_names.empty()) {
|
||||
throw std::invalid_argument(
|
||||
"[metal_kernel] Must specify at least one output.");
|
||||
}
|
||||
std::vector<CustomKernelShapeInfo> shape_infos;
|
||||
for (auto& n : input_names) {
|
||||
CustomKernelShapeInfo shape_info;
|
||||
shape_info.shape = source.find(n + "_shape") != std::string::npos;
|
||||
shape_info.strides = source.find(n + "_strides") != std::string::npos;
|
||||
shape_info.ndim = source.find(n + "_ndim") != std::string::npos;
|
||||
shape_infos.push_back(shape_info);
|
||||
}
|
||||
const std::vector<std::pair<std::string, std::string>> metal_attributes = {
|
||||
{"dispatch_quadgroups_per_threadgroup", "uint"},
|
||||
{"dispatch_simdgroups_per_threadgroup", "uint"},
|
||||
{"dispatch_threads_per_threadgroup", "uint3"},
|
||||
{"grid_origin", "uint3"},
|
||||
{"grid_size", "uint3"},
|
||||
{"quadgroup_index_in_threadgroup", "uint"},
|
||||
{"quadgroups_per_threadgroup", "uint"},
|
||||
{"simdgroup_index_in_threadgroup", "uint"},
|
||||
{"simdgroups_per_threadgroup", "uint"},
|
||||
{"thread_execution_width", "uint"},
|
||||
{"thread_index_in_quadgroup", "uint"},
|
||||
{"thread_index_in_simdgroup", "uint"},
|
||||
{"thread_index_in_threadgroup", "uint"},
|
||||
{"thread_position_in_grid", "uint3"},
|
||||
{"thread_position_in_threadgroup", "uint3"},
|
||||
{"threadgroup_position_in_grid", "uint3"},
|
||||
{"threadgroups_per_grid", "uint3"},
|
||||
{"threads_per_grid", "uint3"},
|
||||
{"threads_per_simdgroup", "uint"},
|
||||
{"threads_per_threadgroup", "uint3"},
|
||||
};
|
||||
|
||||
std::vector<std::string> attributes;
|
||||
for (const auto& [attr, dtype] : metal_attributes) {
|
||||
if (source.find(attr) != std::string::npos) {
|
||||
attributes.push_back(" " + dtype + " " + attr + " [[" + attr + "]]");
|
||||
}
|
||||
}
|
||||
|
||||
return [=,
|
||||
shape_infos = std::move(shape_infos),
|
||||
attributes = std::move(attributes)](
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<Shape>& output_shapes,
|
||||
const std::vector<Dtype>& output_dtypes,
|
||||
std::tuple<int, int, int> grid,
|
||||
std::tuple<int, int, int> threadgroup,
|
||||
const std::vector<std::pair<std::string, TemplateArg>>&
|
||||
template_args = {},
|
||||
std::optional<float> init_value = std::nullopt,
|
||||
bool verbose = false,
|
||||
StreamOrDevice s_ = {}) {
|
||||
if (inputs.size() != input_names.size()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[metal_kernel] Expected `inputs` to have size "
|
||||
<< input_names.size() << " but got size " << inputs.size() << "."
|
||||
<< std::endl;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (output_shapes.size() != output_names.size()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[metal_kernel] Expected `output_shapes` to have size "
|
||||
<< output_names.size() << " but got size " << output_shapes.size()
|
||||
<< "." << std::endl;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (output_dtypes.size() != output_names.size()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[metal_kernel] Expected `output_dtypes` to have size "
|
||||
<< output_names.size() << " but got size " << output_dtypes.size()
|
||||
<< "." << std::endl;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
auto s = to_stream(s_);
|
||||
if (s.device != Device::gpu) {
|
||||
throw std::invalid_argument("[metal_kernel] Only supports the GPU.");
|
||||
}
|
||||
|
||||
std::string kernel_name = "custom_kernel_" + name;
|
||||
std::string template_def = "";
|
||||
if (!template_args.empty()) {
|
||||
std::regex disallowed_chars("\\<|\\>|(, )");
|
||||
template_def = write_template(template_args);
|
||||
auto template_hash =
|
||||
std::regex_replace(template_def, disallowed_chars, "_");
|
||||
template_hash.pop_back();
|
||||
kernel_name += "_";
|
||||
kernel_name += template_hash;
|
||||
}
|
||||
|
||||
std::string kernel_source = write_signature(
|
||||
kernel_name,
|
||||
header,
|
||||
source,
|
||||
input_names,
|
||||
inputs,
|
||||
output_names,
|
||||
output_dtypes,
|
||||
template_args,
|
||||
attributes,
|
||||
shape_infos,
|
||||
atomic_outputs);
|
||||
|
||||
if (!template_args.empty()) {
|
||||
template_def = kernel_name + template_def;
|
||||
kernel_source += "\ntemplate [[host_name(\"";
|
||||
kernel_source += kernel_name;
|
||||
kernel_source += "\")]] [[kernel]] decltype(";
|
||||
kernel_source += template_def;
|
||||
kernel_source += ") ";
|
||||
kernel_source += template_def;
|
||||
kernel_source += ";\n";
|
||||
}
|
||||
|
||||
if (verbose) {
|
||||
std::cout << "Generated source code for `" << name << "`:" << std::endl
|
||||
<< "```" << std::endl
|
||||
<< kernel_source << std::endl
|
||||
<< "```" << std::endl;
|
||||
}
|
||||
|
||||
return array::make_arrays(
|
||||
std::move(output_shapes),
|
||||
std::move(output_dtypes),
|
||||
std::make_shared<CustomKernel>(
|
||||
s,
|
||||
std::move(kernel_name),
|
||||
std::move(kernel_source),
|
||||
grid,
|
||||
threadgroup,
|
||||
shape_infos,
|
||||
ensure_row_contiguous,
|
||||
init_value),
|
||||
std::move(inputs));
|
||||
};
|
||||
}
|
||||
|
||||
void CustomKernel::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
@ -39,9 +353,23 @@ void CustomKernel::eval_gpu(
|
||||
}
|
||||
|
||||
auto& d = metal::device(s.device);
|
||||
const auto& lib_name = name_;
|
||||
auto lib =
|
||||
d.get_library(lib_name, [this] { return metal::utils() + source_; });
|
||||
|
||||
{
|
||||
// Clear kernels from the device library cache if needed
|
||||
auto& kernel_cache = cache();
|
||||
if (auto it = kernel_cache.libraries.find(name_);
|
||||
it != kernel_cache.libraries.end()) {
|
||||
if (it->second != source_) {
|
||||
auto& d = metal::device(s.device);
|
||||
d.clear_library(name_);
|
||||
it->second = source_;
|
||||
}
|
||||
} else {
|
||||
kernel_cache.libraries.emplace(name_, source_);
|
||||
}
|
||||
}
|
||||
|
||||
auto lib = d.get_library(name_, [this] { return metal::utils() + source_; });
|
||||
auto kernel = d.get_kernel(name_, lib);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
@ -73,6 +401,16 @@ void CustomKernel::eval_gpu(
|
||||
}
|
||||
|
||||
const auto [tx, ty, tz] = threadgroup_;
|
||||
auto tg_size = tx * ty * tz;
|
||||
auto max_tg_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (tg_size > max_tg_size) {
|
||||
std::ostringstream msg;
|
||||
msg << "Thread group size (" << tg_size << ") is greater than "
|
||||
<< " the maximum allowed threads per threadgroup (" << max_tg_size
|
||||
<< ").";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
const auto [gx, gy, gz] = grid_;
|
||||
MTL::Size group_dims =
|
||||
MTL::Size(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz));
|
||||
|
@ -295,7 +295,7 @@ void CommandEncoder::barrier() {
|
||||
Device::Device() {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
device_ = load_device();
|
||||
library_map_ = {{"mlx", load_default_library(device_)}};
|
||||
default_library_ = load_default_library(device_);
|
||||
arch_ = std::string(device_->architecture()->name()->utf8String());
|
||||
auto arch = arch_.back();
|
||||
switch (arch) {
|
||||
@ -326,11 +326,11 @@ Device::Device() {
|
||||
|
||||
Device::~Device() {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
for (auto& k : kernel_map_) {
|
||||
k.second->release();
|
||||
for (auto& [l, kernel_map] : library_kernels_) {
|
||||
l->release();
|
||||
for (auto& [_, k] : kernel_map) {
|
||||
k->release();
|
||||
}
|
||||
for (auto& l : library_map_) {
|
||||
l.second->release();
|
||||
}
|
||||
stream_map_.clear();
|
||||
device_->release();
|
||||
@ -474,13 +474,24 @@ CommandEncoder& Device::get_command_encoder(int index) {
|
||||
return *stream.encoder;
|
||||
}
|
||||
|
||||
void Device::register_library(
|
||||
const std::string& lib_name,
|
||||
const std::string& lib_path) {
|
||||
if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
|
||||
auto new_lib = load_library(device_, lib_name, lib_path.c_str());
|
||||
library_map_.insert({lib_name, new_lib});
|
||||
MTL::Library* Device::get_library(
|
||||
const std::string& name,
|
||||
const std::string& path /* = "" */) {
|
||||
{
|
||||
std::shared_lock rlock(library_mtx_);
|
||||
if (auto it = library_map_.find(name); it != library_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_lock wlock(library_mtx_);
|
||||
if (auto it = library_map_.find(name); it != library_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
auto new_lib = load_library(device_, name, path.c_str());
|
||||
library_map_.insert({name, new_lib});
|
||||
return new_lib;
|
||||
}
|
||||
|
||||
MTL::Library* Device::build_library_(const std::string& source_string) {
|
||||
@ -649,6 +660,19 @@ MTL::Library* Device::get_library(
|
||||
return mtl_lib;
|
||||
}
|
||||
|
||||
void Device::clear_library(const std::string& name) {
|
||||
std::unique_lock wlock(library_mtx_);
|
||||
if (auto it = library_map_.find(name); it != library_map_.end()) {
|
||||
auto kernel_map_it = library_kernels_.find(it->second);
|
||||
for (auto& [_, kernel] : kernel_map_it->second) {
|
||||
kernel->release();
|
||||
}
|
||||
library_kernels_.erase(kernel_map_it);
|
||||
it->second->release();
|
||||
library_map_.erase(it);
|
||||
}
|
||||
}
|
||||
|
||||
MTL::LinkedFunctions* Device::get_linked_functions_(
|
||||
const std::vector<MTL::Function*>& funcs) {
|
||||
if (funcs.empty()) {
|
||||
@ -679,6 +703,7 @@ MTL::ComputePipelineState* Device::get_kernel_(
|
||||
std::unique_lock wlock(kernel_mtx_);
|
||||
|
||||
// Try loading again to avoid loading twice
|
||||
auto& kernel_map_ = library_kernels_[mtl_lib];
|
||||
if (auto it = kernel_map_.find(hash_name); it != kernel_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
@ -713,6 +738,7 @@ MTL::ComputePipelineState* Device::get_kernel(
|
||||
std::shared_lock lock(kernel_mtx_);
|
||||
|
||||
// Look for cached kernel
|
||||
auto& kernel_map_ = library_kernels_[mtl_lib];
|
||||
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
@ -722,23 +748,11 @@ MTL::ComputePipelineState* Device::get_kernel(
|
||||
|
||||
MTL::ComputePipelineState* Device::get_kernel(
|
||||
const std::string& base_name,
|
||||
const std::string& lib_name /* = "mlx" */,
|
||||
const std::string& hash_name /* = "" */,
|
||||
const MTLFCList& func_consts /* = {} */,
|
||||
const std::vector<MTL::Function*>& linked_functions /* = {} */) {
|
||||
const auto& kname = hash_name.size() == 0 ? base_name : hash_name;
|
||||
{
|
||||
// Multiple readers allowed
|
||||
std::shared_lock lock(kernel_mtx_);
|
||||
|
||||
// Look for cached kernel
|
||||
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
// Search for cached metal lib
|
||||
MTL::Library* mtl_lib = get_library_(lib_name);
|
||||
return get_kernel_(base_name, mtl_lib, kname, func_consts, linked_functions);
|
||||
return get_kernel(
|
||||
base_name, default_library_, hash_name, func_consts, linked_functions);
|
||||
}
|
||||
|
||||
void Device::set_residency_set(const MTL::ResidencySet* residency_set) {
|
||||
|
@ -187,14 +187,16 @@ class Device {
|
||||
CommandEncoder& get_command_encoder(int index);
|
||||
void end_encoding(int index);
|
||||
|
||||
void register_library(
|
||||
const std::string& lib_name,
|
||||
const std::string& lib_path = "");
|
||||
MTL::Library* get_library(
|
||||
const std::string& name,
|
||||
const std::string& path = "");
|
||||
|
||||
MTL::Library* get_library(
|
||||
const std::string& name,
|
||||
const std::function<std::string(void)>& builder);
|
||||
|
||||
void clear_library(const std::string& name);
|
||||
|
||||
MTL::ComputePipelineState* get_kernel(
|
||||
const std::string& base_name,
|
||||
MTL::Library* mtl_lib,
|
||||
@ -204,7 +206,6 @@ class Device {
|
||||
|
||||
MTL::ComputePipelineState* get_kernel(
|
||||
const std::string& base_name,
|
||||
const std::string& lib_name = "mlx",
|
||||
const std::string& hash_name = "",
|
||||
const MTLFCList& func_consts = {},
|
||||
const std::vector<MTL::Function*>& linked_functions = {});
|
||||
@ -258,10 +259,13 @@ class Device {
|
||||
std::unordered_map<int32_t, DeviceStream> stream_map_;
|
||||
|
||||
std::shared_mutex kernel_mtx_;
|
||||
std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
|
||||
|
||||
std::shared_mutex library_mtx_;
|
||||
std::unordered_map<std::string, MTL::Library*> library_map_;
|
||||
MTL::Library* default_library_;
|
||||
std::unordered_map<
|
||||
MTL::Library*,
|
||||
std::unordered_map<std::string, MTL::ComputePipelineState*>>
|
||||
library_kernels_;
|
||||
const MTL::ResidencySet* residency_set_{nullptr};
|
||||
std::string arch_;
|
||||
int max_ops_per_buffer_;
|
||||
|
@ -146,7 +146,7 @@ MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
|
||||
int,
|
||||
int,
|
||||
int) {
|
||||
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
|
||||
return d.get_kernel(kernel_name, hash_name, func_consts);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
|
||||
@ -207,7 +207,7 @@ MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
|
||||
int,
|
||||
int,
|
||||
bool) {
|
||||
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
|
||||
return d.get_kernel(kernel_name, hash_name, func_consts);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_gemv_masked_kernel(
|
||||
@ -259,7 +259,7 @@ MTL::ComputePipelineState* get_fft_kernel(
|
||||
const std::string& hash_name,
|
||||
const metal::MTLFCList& func_consts,
|
||||
const std::string&) {
|
||||
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
|
||||
return d.get_kernel(kernel_name, hash_name, func_consts);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_quantized_kernel(
|
||||
@ -283,7 +283,7 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
|
||||
int,
|
||||
int,
|
||||
bool) {
|
||||
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
|
||||
return d.get_kernel(kernel_name, hash_name, func_consts);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -172,7 +172,7 @@ void RMSNormVJP::eval_gpu(
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
{
|
||||
auto kernel = d.get_kernel(op_name, "mlx", hash_name, func_consts);
|
||||
auto kernel = d.get_kernel(op_name, hash_name, func_consts);
|
||||
|
||||
MTL::Size grid_dims, group_dims;
|
||||
if (axis_size <= looped_limit) {
|
||||
@ -395,7 +395,7 @@ void LayerNormVJP::eval_gpu(
|
||||
};
|
||||
|
||||
{
|
||||
auto kernel = d.get_kernel(op_name, "mlx", hash_name, func_consts);
|
||||
auto kernel = d.get_kernel(op_name, hash_name, func_consts);
|
||||
|
||||
MTL::Size grid_dims, group_dims;
|
||||
if (axis_size <= looped_limit) {
|
||||
|
@ -73,7 +73,7 @@ void sdpa_full_self_attention_metal(
|
||||
std::string hash_name = kname.str();
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
|
||||
auto kernel = d.get_kernel(base_name, hash_name, func_consts);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
const int NQ = (qL + bq - 1) / bq;
|
||||
@ -180,7 +180,7 @@ void sdpa_vector(
|
||||
|
||||
// Get the kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname, "mlx", hash_name, func_consts);
|
||||
auto kernel = d.get_kernel(kname, hash_name, func_consts);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// Set its arguments
|
||||
@ -281,7 +281,7 @@ void sdpa_vector_2pass(
|
||||
|
||||
// Get the kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname, "mlx", hash_name, func_consts);
|
||||
auto kernel = d.get_kernel(kname, hash_name, func_consts);
|
||||
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/distributed/primitives.h"
|
||||
#include "mlx/fast.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
|
||||
#define NO_GPU_MULTI(func) \
|
||||
@ -155,6 +156,18 @@ NO_GPU_USE_FALLBACK(RoPE)
|
||||
NO_GPU(ScaledDotProductAttention)
|
||||
NO_GPU_MULTI(AffineQuantize)
|
||||
NO_GPU_MULTI(CustomKernel)
|
||||
|
||||
MetalKernelFunction metal_kernel(
|
||||
const std::string&,
|
||||
const std::vector<std::string>&,
|
||||
const std::vector<std::string>&,
|
||||
const std::string&,
|
||||
const std::string&,
|
||||
bool ensure_row_contiguous,
|
||||
bool atomic_outputs) {
|
||||
throw std::runtime_error("[metal_kernel] No GPU back-end.");
|
||||
}
|
||||
|
||||
} // namespace fast
|
||||
|
||||
namespace distributed {
|
||||
|
302
mlx/fast.cpp
302
mlx/fast.cpp
@ -1,10 +1,7 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <regex>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/fast.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/ops.h"
|
||||
@ -1027,303 +1024,4 @@ std::vector<Shape> AffineQuantize::output_shapes(
|
||||
}
|
||||
}
|
||||
|
||||
std::string write_signature(
|
||||
std::string func_name,
|
||||
const std::string& header,
|
||||
const std::string& source,
|
||||
const std::vector<std::string>& input_names,
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<std::string>& output_names,
|
||||
const std::vector<Dtype>& output_dtypes,
|
||||
const std::vector<std::pair<std::string, TemplateArg>>& template_args,
|
||||
const std::vector<std::string>& attributes,
|
||||
const std::vector<CustomKernelShapeInfo>& shape_infos,
|
||||
bool atomic_outputs) {
|
||||
std::string kernel_source;
|
||||
kernel_source.reserve(header.size() + source.size() + 16384);
|
||||
kernel_source += header;
|
||||
// Auto-generate a function signature based on `template_args`
|
||||
// and the dtype/shape of the arrays passed as `inputs`.
|
||||
if (!template_args.empty()) {
|
||||
kernel_source += "template <";
|
||||
int i = 0;
|
||||
for (const auto& [name, arg] : template_args) {
|
||||
std::string param_type;
|
||||
if (std::holds_alternative<int>(arg)) {
|
||||
param_type = "int";
|
||||
} else if (std::holds_alternative<bool>(arg)) {
|
||||
param_type = "bool";
|
||||
} else if (std::holds_alternative<Dtype>(arg)) {
|
||||
param_type = "typename";
|
||||
}
|
||||
if (i > 0) {
|
||||
kernel_source += ", ";
|
||||
}
|
||||
kernel_source += param_type;
|
||||
kernel_source += " ";
|
||||
kernel_source += name;
|
||||
i++;
|
||||
}
|
||||
kernel_source += ">\n";
|
||||
}
|
||||
kernel_source += "[[kernel]] void ";
|
||||
kernel_source += func_name;
|
||||
kernel_source += "(\n";
|
||||
|
||||
int index = 0;
|
||||
constexpr int max_constant_array_size = 8;
|
||||
// Add inputs
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
const auto& name = input_names[i];
|
||||
const auto& arr = inputs[i];
|
||||
auto dtype = get_type_string(arr.dtype());
|
||||
std::string location =
|
||||
arr.size() < max_constant_array_size ? "constant" : "device";
|
||||
std::string ref = arr.ndim() == 0 ? "&" : "*";
|
||||
kernel_source += " const ";
|
||||
kernel_source += location;
|
||||
kernel_source += " ";
|
||||
kernel_source += dtype;
|
||||
kernel_source += ref;
|
||||
kernel_source += " ";
|
||||
kernel_source += name;
|
||||
kernel_source += " [[buffer(";
|
||||
kernel_source += std::to_string(index);
|
||||
kernel_source += ")]],\n";
|
||||
index++;
|
||||
// Add input shape, strides and ndim if present in the source
|
||||
if (arr.ndim() > 0) {
|
||||
if (shape_infos[i].shape) {
|
||||
kernel_source +=
|
||||
(" const constant int* " + name + "_shape [[buffer(" +
|
||||
std::to_string(index) + ")]],\n");
|
||||
index++;
|
||||
}
|
||||
if (shape_infos[i].strides) {
|
||||
kernel_source +=
|
||||
(" const constant int64_t* " + name + "_strides [[buffer(" +
|
||||
std::to_string(index) + ")]],\n");
|
||||
index++;
|
||||
}
|
||||
if (shape_infos[i].ndim) {
|
||||
kernel_source +=
|
||||
(" const constant int& " + name + "_ndim [[buffer(" +
|
||||
std::to_string(index) + ")]],\n");
|
||||
index++;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Add outputs
|
||||
for (int i = 0; i < output_names.size(); ++i) {
|
||||
const auto& name = output_names[i];
|
||||
const auto& dtype = output_dtypes[i];
|
||||
kernel_source += " device ";
|
||||
auto type_string = get_type_string(dtype);
|
||||
if (atomic_outputs) {
|
||||
kernel_source += "atomic<";
|
||||
}
|
||||
kernel_source += type_string;
|
||||
if (atomic_outputs) {
|
||||
kernel_source += ">";
|
||||
}
|
||||
kernel_source += "* ";
|
||||
kernel_source += name;
|
||||
kernel_source += " [[buffer(";
|
||||
kernel_source += std::to_string(index);
|
||||
kernel_source += ")]]";
|
||||
if (index < inputs.size() + output_names.size() - 1 ||
|
||||
attributes.size() > 0) {
|
||||
kernel_source += ",\n";
|
||||
} else {
|
||||
kernel_source += ") {\n";
|
||||
}
|
||||
index++;
|
||||
}
|
||||
|
||||
index = 0;
|
||||
for (const auto& attr : attributes) {
|
||||
kernel_source += attr;
|
||||
if (index < attributes.size() - 1) {
|
||||
kernel_source += ",\n";
|
||||
} else {
|
||||
kernel_source += ") {\n";
|
||||
}
|
||||
index++;
|
||||
}
|
||||
kernel_source += source;
|
||||
kernel_source += "\n}\n";
|
||||
return kernel_source;
|
||||
}
|
||||
|
||||
std::string write_template(
|
||||
const std::vector<std::pair<std::string, TemplateArg>>& template_args) {
|
||||
std::ostringstream template_def;
|
||||
template_def << "<";
|
||||
int i = 0;
|
||||
for (const auto& [name, arg] : template_args) {
|
||||
if (i > 0) {
|
||||
template_def << ", ";
|
||||
}
|
||||
if (std::holds_alternative<int>(arg)) {
|
||||
template_def << std::get<int>(arg);
|
||||
} else if (std::holds_alternative<bool>(arg)) {
|
||||
template_def << std::get<bool>(arg);
|
||||
} else if (std::holds_alternative<Dtype>(arg)) {
|
||||
template_def << get_type_string(std::get<Dtype>(arg));
|
||||
}
|
||||
i++;
|
||||
}
|
||||
template_def << ">";
|
||||
return template_def.str();
|
||||
}
|
||||
|
||||
MetalKernelFunction metal_kernel(
|
||||
const std::string& name,
|
||||
const std::vector<std::string>& input_names,
|
||||
const std::vector<std::string>& output_names,
|
||||
const std::string& source,
|
||||
const std::string& header /* = "" */,
|
||||
bool ensure_row_contiguous /* = true */,
|
||||
bool atomic_outputs /* = false */) {
|
||||
if (output_names.empty()) {
|
||||
throw std::invalid_argument(
|
||||
"[metal_kernel] Must specify at least one output.");
|
||||
}
|
||||
std::vector<CustomKernelShapeInfo> shape_infos;
|
||||
for (auto& n : input_names) {
|
||||
CustomKernelShapeInfo shape_info;
|
||||
shape_info.shape = source.find(n + "_shape") != std::string::npos;
|
||||
shape_info.strides = source.find(n + "_strides") != std::string::npos;
|
||||
shape_info.ndim = source.find(n + "_ndim") != std::string::npos;
|
||||
shape_infos.push_back(shape_info);
|
||||
}
|
||||
const std::vector<std::pair<std::string, std::string>> metal_attributes = {
|
||||
{"dispatch_quadgroups_per_threadgroup", "uint"},
|
||||
{"dispatch_simdgroups_per_threadgroup", "uint"},
|
||||
{"dispatch_threads_per_threadgroup", "uint3"},
|
||||
{"grid_origin", "uint3"},
|
||||
{"grid_size", "uint3"},
|
||||
{"quadgroup_index_in_threadgroup", "uint"},
|
||||
{"quadgroups_per_threadgroup", "uint"},
|
||||
{"simdgroup_index_in_threadgroup", "uint"},
|
||||
{"simdgroups_per_threadgroup", "uint"},
|
||||
{"thread_execution_width", "uint"},
|
||||
{"thread_index_in_quadgroup", "uint"},
|
||||
{"thread_index_in_simdgroup", "uint"},
|
||||
{"thread_index_in_threadgroup", "uint"},
|
||||
{"thread_position_in_grid", "uint3"},
|
||||
{"thread_position_in_threadgroup", "uint3"},
|
||||
{"threadgroup_position_in_grid", "uint3"},
|
||||
{"threadgroups_per_grid", "uint3"},
|
||||
{"threads_per_grid", "uint3"},
|
||||
{"threads_per_simdgroup", "uint"},
|
||||
{"threads_per_threadgroup", "uint3"},
|
||||
};
|
||||
|
||||
std::vector<std::string> attributes;
|
||||
for (const auto& [attr, dtype] : metal_attributes) {
|
||||
if (source.find(attr) != std::string::npos) {
|
||||
attributes.push_back(" " + dtype + " " + attr + " [[" + attr + "]]");
|
||||
}
|
||||
}
|
||||
|
||||
return [=,
|
||||
shape_infos = std::move(shape_infos),
|
||||
attributes = std::move(attributes)](
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<Shape>& output_shapes,
|
||||
const std::vector<Dtype>& output_dtypes,
|
||||
std::tuple<int, int, int> grid,
|
||||
std::tuple<int, int, int> threadgroup,
|
||||
const std::vector<std::pair<std::string, TemplateArg>>&
|
||||
template_args = {},
|
||||
std::optional<float> init_value = std::nullopt,
|
||||
bool verbose = false,
|
||||
StreamOrDevice s_ = {}) {
|
||||
if (inputs.size() != input_names.size()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[metal_kernel] Expected `inputs` to have size "
|
||||
<< input_names.size() << " but got size " << inputs.size() << "."
|
||||
<< std::endl;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (output_shapes.size() != output_names.size()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[metal_kernel] Expected `output_shapes` to have size "
|
||||
<< output_names.size() << " but got size " << output_shapes.size()
|
||||
<< "." << std::endl;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (output_dtypes.size() != output_names.size()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[metal_kernel] Expected `output_dtypes` to have size "
|
||||
<< output_names.size() << " but got size " << output_dtypes.size()
|
||||
<< "." << std::endl;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
auto s = to_stream(s_);
|
||||
if (s.device != Device::gpu) {
|
||||
throw std::invalid_argument("[metal_kernel] Only supports the GPU.");
|
||||
}
|
||||
|
||||
std::ostringstream func_name;
|
||||
std::string template_def = "";
|
||||
std::string hash_key = "";
|
||||
if (!template_args.empty()) {
|
||||
std::regex disallowed_chars("\\<|\\>|(, )");
|
||||
template_def = write_template(template_args);
|
||||
hash_key = std::regex_replace(template_def, disallowed_chars, "_");
|
||||
hash_key.pop_back();
|
||||
}
|
||||
func_name << "custom_kernel_" << name << hash_key;
|
||||
std::string kernel_name = func_name.str();
|
||||
|
||||
std::string kernel_source = write_signature(
|
||||
kernel_name,
|
||||
header,
|
||||
source,
|
||||
input_names,
|
||||
inputs,
|
||||
output_names,
|
||||
output_dtypes,
|
||||
template_args,
|
||||
attributes,
|
||||
shape_infos,
|
||||
atomic_outputs);
|
||||
|
||||
if (!template_args.empty()) {
|
||||
template_def = kernel_name + template_def;
|
||||
kernel_source += "\ntemplate [[host_name(\"";
|
||||
kernel_source += kernel_name;
|
||||
kernel_source += "\")]] [[kernel]] decltype(";
|
||||
kernel_source += template_def;
|
||||
kernel_source += ") ";
|
||||
kernel_source += template_def;
|
||||
kernel_source += ";\n";
|
||||
}
|
||||
|
||||
if (verbose) {
|
||||
std::cout << "Generated source code for `" << name << "`:" << std::endl
|
||||
<< "```" << std::endl
|
||||
<< kernel_source << std::endl
|
||||
<< "```" << std::endl;
|
||||
}
|
||||
|
||||
return array::make_arrays(
|
||||
std::move(output_shapes),
|
||||
std::move(output_dtypes),
|
||||
std::make_shared<CustomKernel>(
|
||||
s,
|
||||
std::move(kernel_name),
|
||||
std::move(kernel_source),
|
||||
grid,
|
||||
threadgroup,
|
||||
shape_infos,
|
||||
ensure_row_contiguous,
|
||||
init_value),
|
||||
std::move(inputs));
|
||||
};
|
||||
}
|
||||
|
||||
} // namespace mlx::core::fast
|
||||
|
@ -735,6 +735,41 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
)[0]
|
||||
self.assertEqual(out.item(), 2)
|
||||
|
||||
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
||||
def test_custom_kernel_caching(self):
|
||||
def call_kernel(a: mx.array, source):
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="my_kernel",
|
||||
input_names=["inp"],
|
||||
output_names=["out"],
|
||||
source=source,
|
||||
)
|
||||
return kernel(
|
||||
inputs=[a],
|
||||
grid=(a.size, 1, 1),
|
||||
threadgroup=(a.size, 1, 1),
|
||||
output_shapes=[a.shape],
|
||||
output_dtypes=[a.dtype],
|
||||
stream=mx.gpu,
|
||||
)[0]
|
||||
|
||||
a = mx.random.normal(shape=(32,))
|
||||
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
out[elem] = 0.0;
|
||||
"""
|
||||
|
||||
out = call_kernel(a, source)
|
||||
self.assertTrue(mx.array_equal(out, mx.zeros_like(out)))
|
||||
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
out[elem] = 1.0;
|
||||
"""
|
||||
out = call_kernel(a, source)
|
||||
self.assertTrue(mx.array_equal(out, mx.ones_like(out)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user