mlx/mlx/backend/metal/custom_kernel.cpp
Angelos Katharopoulos 14efd9c35a Fix compilation
2025-08-20 00:31:10 -07:00

427 lines
13 KiB
C++

// Copyright © 2024 Apple Inc.
#include <iostream>
#include <regex>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast.h"
#include "mlx/fast_primitives.h"
#include "mlx/utils.h"
namespace mlx::core::fast {
struct CustomKernelCache {
std::unordered_map<std::string, std::string> libraries;
};
static CustomKernelCache& cache() {
static CustomKernelCache cache_;
return cache_;
};
std::string write_signature(
std::string func_name,
const std::string& header,
const std::string& source,
const std::vector<std::string>& input_names,
const std::vector<array>& inputs,
const std::vector<std::string>& output_names,
const std::vector<Dtype>& output_dtypes,
const std::vector<std::pair<std::string, TemplateArg>>& template_args,
const std::vector<std::string>& attributes,
const std::vector<CustomKernelShapeInfo>& shape_infos,
bool atomic_outputs) {
std::string kernel_source;
kernel_source.reserve(header.size() + source.size() + 16384);
kernel_source += header;
// Auto-generate a function signature based on `template_args`
// and the dtype/shape of the arrays passed as `inputs`.
if (!template_args.empty()) {
kernel_source += "template <";
int i = 0;
for (const auto& [name, arg] : template_args) {
std::string param_type;
if (std::holds_alternative<int>(arg)) {
param_type = "int";
} else if (std::holds_alternative<bool>(arg)) {
param_type = "bool";
} else if (std::holds_alternative<Dtype>(arg)) {
param_type = "typename";
}
if (i > 0) {
kernel_source += ", ";
}
kernel_source += param_type;
kernel_source += " ";
kernel_source += name;
i++;
}
kernel_source += ">\n";
}
kernel_source += "[[kernel]] void ";
kernel_source += func_name;
kernel_source += "(\n";
int index = 0;
constexpr int max_constant_array_size = 8;
// Add inputs
for (int i = 0; i < inputs.size(); ++i) {
const auto& name = input_names[i];
const auto& arr = inputs[i];
auto dtype = get_type_string(arr.dtype());
std::string location =
arr.size() < max_constant_array_size ? "constant" : "device";
std::string ref = arr.ndim() == 0 ? "&" : "*";
kernel_source += " const ";
kernel_source += location;
kernel_source += " ";
kernel_source += dtype;
kernel_source += ref;
kernel_source += " ";
kernel_source += name;
kernel_source += " [[buffer(";
kernel_source += std::to_string(index);
kernel_source += ")]],\n";
index++;
// Add input shape, strides and ndim if present in the source
if (arr.ndim() > 0) {
if (shape_infos[i].shape) {
kernel_source +=
(" const constant int* " + name + "_shape [[buffer(" +
std::to_string(index) + ")]],\n");
index++;
}
if (shape_infos[i].strides) {
kernel_source +=
(" const constant int64_t* " + name + "_strides [[buffer(" +
std::to_string(index) + ")]],\n");
index++;
}
if (shape_infos[i].ndim) {
kernel_source +=
(" const constant int& " + name + "_ndim [[buffer(" +
std::to_string(index) + ")]],\n");
index++;
}
}
}
// Add outputs
for (int i = 0; i < output_names.size(); ++i) {
const auto& name = output_names[i];
const auto& dtype = output_dtypes[i];
kernel_source += " device ";
auto type_string = get_type_string(dtype);
if (atomic_outputs) {
kernel_source += "atomic<";
}
kernel_source += type_string;
if (atomic_outputs) {
kernel_source += ">";
}
kernel_source += "* ";
kernel_source += name;
kernel_source += " [[buffer(";
kernel_source += std::to_string(index);
kernel_source += ")]]";
if (index < inputs.size() + output_names.size() - 1 ||
attributes.size() > 0) {
kernel_source += ",\n";
} else {
kernel_source += ") {\n";
}
index++;
}
index = 0;
for (const auto& attr : attributes) {
kernel_source += attr;
if (index < attributes.size() - 1) {
kernel_source += ",\n";
} else {
kernel_source += ") {\n";
}
index++;
}
kernel_source += source;
kernel_source += "\n}\n";
return kernel_source;
}
std::string write_template(
const std::vector<std::pair<std::string, TemplateArg>>& template_args) {
std::ostringstream template_def;
template_def << "<";
int i = 0;
for (const auto& [name, arg] : template_args) {
if (i > 0) {
template_def << ", ";
}
if (std::holds_alternative<int>(arg)) {
template_def << std::get<int>(arg);
} else if (std::holds_alternative<bool>(arg)) {
template_def << std::get<bool>(arg);
} else if (std::holds_alternative<Dtype>(arg)) {
template_def << get_type_string(std::get<Dtype>(arg));
}
i++;
}
template_def << ">";
return template_def.str();
}
CustomKernelFunction metal_kernel(
const std::string& name,
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
const std::string& source,
const std::string& header /* = "" */,
bool ensure_row_contiguous /* = true */,
bool atomic_outputs /* = false */) {
if (output_names.empty()) {
throw std::invalid_argument(
"[metal_kernel] Must specify at least one output.");
}
std::vector<CustomKernelShapeInfo> shape_infos;
for (auto& n : input_names) {
CustomKernelShapeInfo shape_info;
shape_info.shape = source.find(n + "_shape") != std::string::npos;
shape_info.strides = source.find(n + "_strides") != std::string::npos;
shape_info.ndim = source.find(n + "_ndim") != std::string::npos;
shape_infos.push_back(shape_info);
}
const std::vector<std::pair<std::string, std::string>> metal_attributes = {
{"dispatch_quadgroups_per_threadgroup", "uint"},
{"dispatch_simdgroups_per_threadgroup", "uint"},
{"dispatch_threads_per_threadgroup", "uint3"},
{"grid_origin", "uint3"},
{"grid_size", "uint3"},
{"quadgroup_index_in_threadgroup", "uint"},
{"quadgroups_per_threadgroup", "uint"},
{"simdgroup_index_in_threadgroup", "uint"},
{"simdgroups_per_threadgroup", "uint"},
{"thread_execution_width", "uint"},
{"thread_index_in_quadgroup", "uint"},
{"thread_index_in_simdgroup", "uint"},
{"thread_index_in_threadgroup", "uint"},
{"thread_position_in_grid", "uint3"},
{"thread_position_in_threadgroup", "uint3"},
{"threadgroup_position_in_grid", "uint3"},
{"threadgroups_per_grid", "uint3"},
{"threads_per_grid", "uint3"},
{"threads_per_simdgroup", "uint"},
{"threads_per_threadgroup", "uint3"},
};
std::vector<std::string> attributes;
for (const auto& [attr, dtype] : metal_attributes) {
if (source.find(attr) != std::string::npos) {
attributes.push_back(" " + dtype + " " + attr + " [[" + attr + "]]");
}
}
return [=,
shape_infos = std::move(shape_infos),
attributes = std::move(attributes)](
const std::vector<array>& inputs,
const std::vector<Shape>& output_shapes,
const std::vector<Dtype>& output_dtypes,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
const std::vector<std::pair<std::string, TemplateArg>>&
template_args = {},
std::optional<float> init_value = std::nullopt,
bool verbose = false,
StreamOrDevice s_ = {}) {
if (inputs.size() != input_names.size()) {
std::ostringstream msg;
msg << "[metal_kernel] Expected `inputs` to have size "
<< input_names.size() << " but got size " << inputs.size() << "."
<< std::endl;
throw std::invalid_argument(msg.str());
}
if (output_shapes.size() != output_names.size()) {
std::ostringstream msg;
msg << "[metal_kernel] Expected `output_shapes` to have size "
<< output_names.size() << " but got size " << output_shapes.size()
<< "." << std::endl;
throw std::invalid_argument(msg.str());
}
if (output_dtypes.size() != output_names.size()) {
std::ostringstream msg;
msg << "[metal_kernel] Expected `output_dtypes` to have size "
<< output_names.size() << " but got size " << output_dtypes.size()
<< "." << std::endl;
throw std::invalid_argument(msg.str());
}
auto s = to_stream(s_);
if (s.device != Device::gpu) {
throw std::invalid_argument("[metal_kernel] Only supports the GPU.");
}
std::string kernel_name = "custom_kernel_" + name;
std::string template_def = "";
if (!template_args.empty()) {
std::regex disallowed_chars("\\<|\\>|(, )");
template_def = write_template(template_args);
auto template_hash =
std::regex_replace(template_def, disallowed_chars, "_");
template_hash.pop_back();
kernel_name += "_";
kernel_name += template_hash;
}
std::string kernel_source = write_signature(
kernel_name,
header,
source,
input_names,
inputs,
output_names,
output_dtypes,
template_args,
attributes,
shape_infos,
atomic_outputs);
if (!template_args.empty()) {
template_def = kernel_name + template_def;
kernel_source += "\ntemplate [[host_name(\"";
kernel_source += kernel_name;
kernel_source += "\")]] [[kernel]] decltype(";
kernel_source += template_def;
kernel_source += ") ";
kernel_source += template_def;
kernel_source += ";\n";
}
if (verbose) {
std::cout << "Generated source code for `" << name << "`:" << std::endl
<< "```" << std::endl
<< kernel_source << std::endl
<< "```" << std::endl;
}
return array::make_arrays(
std::move(output_shapes),
std::move(output_dtypes),
std::make_shared<CustomKernel>(
s,
std::move(kernel_name),
std::move(kernel_source),
grid,
threadgroup,
shape_infos,
ensure_row_contiguous,
init_value,
std::vector<ScalarArg>{},
false,
0),
std::move(inputs));
};
}
void CustomKernel::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& s = stream();
std::vector<array> copies;
for (auto& out : outputs) {
if (init_value_) {
copies.emplace_back(init_value_.value(), out.dtype());
fill_gpu(copies.back(), out, s);
} else {
out.set_data(allocator::malloc(out.nbytes()));
}
}
auto check_input = [&copies, &s, this](const array& x) -> const array {
bool no_copy = x.flags().row_contiguous;
if (!ensure_row_contiguous_ || no_copy) {
return x;
} else {
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
copy_gpu(x, copies.back(), CopyType::General, s);
return copies.back();
}
};
std::vector<array> checked_inputs;
for (const array& in : inputs) {
checked_inputs.push_back(check_input(in));
}
auto& d = metal::device(s.device);
{
// Clear kernels from the device library cache if needed
auto& kernel_cache = cache();
if (auto it = kernel_cache.libraries.find(name_);
it != kernel_cache.libraries.end()) {
if (it->second != source_) {
auto& d = metal::device(s.device);
d.clear_library(name_);
it->second = source_;
}
} else {
kernel_cache.libraries.emplace(name_, source_);
}
}
auto lib = d.get_library(name_, [this] { return metal::utils() + source_; });
auto kernel = d.get_kernel(name_, lib);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
int index = 0;
for (int i = 0; i < checked_inputs.size(); i++) {
const array& in = checked_inputs[i];
auto& shape_info = shape_infos_[i];
compute_encoder.set_input_array(in, index);
index++;
if (in.ndim() > 0) {
int ndim = in.ndim();
if (shape_info.shape) {
compute_encoder.set_vector_bytes(in.shape(), ndim, index);
index++;
}
if (shape_info.strides) {
compute_encoder.set_vector_bytes(in.strides(), ndim, index);
index++;
}
if (shape_info.ndim) {
compute_encoder.set_bytes(ndim, index);
index++;
}
}
}
for (auto& out : outputs) {
compute_encoder.set_output_array(out, index);
index++;
}
const auto [tx, ty, tz] = threadgroup_;
auto tg_size = tx * ty * tz;
auto max_tg_size = kernel->maxTotalThreadsPerThreadgroup();
if (tg_size > max_tg_size) {
std::ostringstream msg;
msg << "Thread group size (" << tg_size << ") is greater than "
<< " the maximum allowed threads per threadgroup (" << max_tg_size
<< ").";
throw std::invalid_argument(msg.str());
}
const auto [gx, gy, gz] = grid_;
MTL::Size group_dims =
MTL::Size(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz));
MTL::Size grid_dims = MTL::Size(gx, gy, gz);
compute_encoder.dispatch_threads(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index);
}
} // namespace mlx::core::fast