alternative solution

This commit is contained in:
Awni Hannun 2025-06-05 20:31:05 -07:00
parent ac1117b224
commit 6741d15735
11 changed files with 410 additions and 359 deletions

View File

@ -397,11 +397,11 @@ below.
std::ostringstream kname;
kname << "axpby_" << "general_" << type_to_name(out);
// Make sure the metal library is available
d.register_library("mlx_ext");
// Load the metal library
auto lib = d.get_library("mlx_ext");
// Make a kernel from this metal library
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
auto kernel = d.get_kernel(kname.str(), lib);
// Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index);

View File

@ -172,11 +172,11 @@ void Axpby::eval_gpu(
kname << (contiguous_kernel ? "contiguous_" : "general_");
kname << type_to_name(out);
// Make sure the metal library is available
d.register_library("mlx_ext");
// Load the metal library
auto lib = d.get_library("mlx_ext");
// Make a kernel from this metal library
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
auto kernel = d.get_kernel(kname.str(), lib);
// Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index);

View File

@ -677,7 +677,7 @@ void depthwise_conv_2D_gpu(
std::string hash_name = kname.str();
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
auto kernel = d.get_kernel(base_name, hash_name, func_consts);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0);

View File

@ -1,12 +1,326 @@
// Copyright © 2024 Apple Inc.
#include <iostream>
#include <regex>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast.h"
#include "mlx/fast_primitives.h"
#include "mlx/utils.h"
namespace mlx::core::fast {
struct CustomKernelCache {
std::unordered_map<std::string, std::string> libraries;
};
static CustomKernelCache& cache() {
static CustomKernelCache cache_;
return cache_;
};
std::string write_signature(
std::string func_name,
const std::string& header,
const std::string& source,
const std::vector<std::string>& input_names,
const std::vector<array>& inputs,
const std::vector<std::string>& output_names,
const std::vector<Dtype>& output_dtypes,
const std::vector<std::pair<std::string, TemplateArg>>& template_args,
const std::vector<std::string>& attributes,
const std::vector<CustomKernelShapeInfo>& shape_infos,
bool atomic_outputs) {
std::string kernel_source;
kernel_source.reserve(header.size() + source.size() + 16384);
kernel_source += header;
// Auto-generate a function signature based on `template_args`
// and the dtype/shape of the arrays passed as `inputs`.
if (!template_args.empty()) {
kernel_source += "template <";
int i = 0;
for (const auto& [name, arg] : template_args) {
std::string param_type;
if (std::holds_alternative<int>(arg)) {
param_type = "int";
} else if (std::holds_alternative<bool>(arg)) {
param_type = "bool";
} else if (std::holds_alternative<Dtype>(arg)) {
param_type = "typename";
}
if (i > 0) {
kernel_source += ", ";
}
kernel_source += param_type;
kernel_source += " ";
kernel_source += name;
i++;
}
kernel_source += ">\n";
}
kernel_source += "[[kernel]] void ";
kernel_source += func_name;
kernel_source += "(\n";
int index = 0;
constexpr int max_constant_array_size = 8;
// Add inputs
for (int i = 0; i < inputs.size(); ++i) {
const auto& name = input_names[i];
const auto& arr = inputs[i];
auto dtype = get_type_string(arr.dtype());
std::string location =
arr.size() < max_constant_array_size ? "constant" : "device";
std::string ref = arr.ndim() == 0 ? "&" : "*";
kernel_source += " const ";
kernel_source += location;
kernel_source += " ";
kernel_source += dtype;
kernel_source += ref;
kernel_source += " ";
kernel_source += name;
kernel_source += " [[buffer(";
kernel_source += std::to_string(index);
kernel_source += ")]],\n";
index++;
// Add input shape, strides and ndim if present in the source
if (arr.ndim() > 0) {
if (shape_infos[i].shape) {
kernel_source +=
(" const constant int* " + name + "_shape [[buffer(" +
std::to_string(index) + ")]],\n");
index++;
}
if (shape_infos[i].strides) {
kernel_source +=
(" const constant int64_t* " + name + "_strides [[buffer(" +
std::to_string(index) + ")]],\n");
index++;
}
if (shape_infos[i].ndim) {
kernel_source +=
(" const constant int& " + name + "_ndim [[buffer(" +
std::to_string(index) + ")]],\n");
index++;
}
}
}
// Add outputs
for (int i = 0; i < output_names.size(); ++i) {
const auto& name = output_names[i];
const auto& dtype = output_dtypes[i];
kernel_source += " device ";
auto type_string = get_type_string(dtype);
if (atomic_outputs) {
kernel_source += "atomic<";
}
kernel_source += type_string;
if (atomic_outputs) {
kernel_source += ">";
}
kernel_source += "* ";
kernel_source += name;
kernel_source += " [[buffer(";
kernel_source += std::to_string(index);
kernel_source += ")]]";
if (index < inputs.size() + output_names.size() - 1 ||
attributes.size() > 0) {
kernel_source += ",\n";
} else {
kernel_source += ") {\n";
}
index++;
}
index = 0;
for (const auto& attr : attributes) {
kernel_source += attr;
if (index < attributes.size() - 1) {
kernel_source += ",\n";
} else {
kernel_source += ") {\n";
}
index++;
}
kernel_source += source;
kernel_source += "\n}\n";
return kernel_source;
}
std::string write_template(
const std::vector<std::pair<std::string, TemplateArg>>& template_args) {
std::ostringstream template_def;
template_def << "<";
int i = 0;
for (const auto& [name, arg] : template_args) {
if (i > 0) {
template_def << ", ";
}
if (std::holds_alternative<int>(arg)) {
template_def << std::get<int>(arg);
} else if (std::holds_alternative<bool>(arg)) {
template_def << std::get<bool>(arg);
} else if (std::holds_alternative<Dtype>(arg)) {
template_def << get_type_string(std::get<Dtype>(arg));
}
i++;
}
template_def << ">";
return template_def.str();
}
MetalKernelFunction metal_kernel(
const std::string& name,
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
const std::string& source,
const std::string& header /* = "" */,
bool ensure_row_contiguous /* = true */,
bool atomic_outputs /* = false */) {
if (output_names.empty()) {
throw std::invalid_argument(
"[metal_kernel] Must specify at least one output.");
}
std::vector<CustomKernelShapeInfo> shape_infos;
for (auto& n : input_names) {
CustomKernelShapeInfo shape_info;
shape_info.shape = source.find(n + "_shape") != std::string::npos;
shape_info.strides = source.find(n + "_strides") != std::string::npos;
shape_info.ndim = source.find(n + "_ndim") != std::string::npos;
shape_infos.push_back(shape_info);
}
const std::vector<std::pair<std::string, std::string>> metal_attributes = {
{"dispatch_quadgroups_per_threadgroup", "uint"},
{"dispatch_simdgroups_per_threadgroup", "uint"},
{"dispatch_threads_per_threadgroup", "uint3"},
{"grid_origin", "uint3"},
{"grid_size", "uint3"},
{"quadgroup_index_in_threadgroup", "uint"},
{"quadgroups_per_threadgroup", "uint"},
{"simdgroup_index_in_threadgroup", "uint"},
{"simdgroups_per_threadgroup", "uint"},
{"thread_execution_width", "uint"},
{"thread_index_in_quadgroup", "uint"},
{"thread_index_in_simdgroup", "uint"},
{"thread_index_in_threadgroup", "uint"},
{"thread_position_in_grid", "uint3"},
{"thread_position_in_threadgroup", "uint3"},
{"threadgroup_position_in_grid", "uint3"},
{"threadgroups_per_grid", "uint3"},
{"threads_per_grid", "uint3"},
{"threads_per_simdgroup", "uint"},
{"threads_per_threadgroup", "uint3"},
};
std::vector<std::string> attributes;
for (const auto& [attr, dtype] : metal_attributes) {
if (source.find(attr) != std::string::npos) {
attributes.push_back(" " + dtype + " " + attr + " [[" + attr + "]]");
}
}
return [=,
shape_infos = std::move(shape_infos),
attributes = std::move(attributes)](
const std::vector<array>& inputs,
const std::vector<Shape>& output_shapes,
const std::vector<Dtype>& output_dtypes,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
const std::vector<std::pair<std::string, TemplateArg>>&
template_args = {},
std::optional<float> init_value = std::nullopt,
bool verbose = false,
StreamOrDevice s_ = {}) {
if (inputs.size() != input_names.size()) {
std::ostringstream msg;
msg << "[metal_kernel] Expected `inputs` to have size "
<< input_names.size() << " but got size " << inputs.size() << "."
<< std::endl;
throw std::invalid_argument(msg.str());
}
if (output_shapes.size() != output_names.size()) {
std::ostringstream msg;
msg << "[metal_kernel] Expected `output_shapes` to have size "
<< output_names.size() << " but got size " << output_shapes.size()
<< "." << std::endl;
throw std::invalid_argument(msg.str());
}
if (output_dtypes.size() != output_names.size()) {
std::ostringstream msg;
msg << "[metal_kernel] Expected `output_dtypes` to have size "
<< output_names.size() << " but got size " << output_dtypes.size()
<< "." << std::endl;
throw std::invalid_argument(msg.str());
}
auto s = to_stream(s_);
if (s.device != Device::gpu) {
throw std::invalid_argument("[metal_kernel] Only supports the GPU.");
}
std::string kernel_name = "custom_kernel_" + name;
std::string template_def = "";
if (!template_args.empty()) {
std::regex disallowed_chars("\\<|\\>|(, )");
template_def = write_template(template_args);
auto template_hash =
std::regex_replace(template_def, disallowed_chars, "_");
template_hash.pop_back();
kernel_name += "_";
kernel_name += template_hash;
}
std::string kernel_source = write_signature(
kernel_name,
header,
source,
input_names,
inputs,
output_names,
output_dtypes,
template_args,
attributes,
shape_infos,
atomic_outputs);
if (!template_args.empty()) {
template_def = kernel_name + template_def;
kernel_source += "\ntemplate [[host_name(\"";
kernel_source += kernel_name;
kernel_source += "\")]] [[kernel]] decltype(";
kernel_source += template_def;
kernel_source += ") ";
kernel_source += template_def;
kernel_source += ";\n";
}
if (verbose) {
std::cout << "Generated source code for `" << name << "`:" << std::endl
<< "```" << std::endl
<< kernel_source << std::endl
<< "```" << std::endl;
}
return array::make_arrays(
std::move(output_shapes),
std::move(output_dtypes),
std::make_shared<CustomKernel>(
s,
std::move(kernel_name),
std::move(kernel_source),
grid,
threadgroup,
shape_infos,
ensure_row_contiguous,
init_value),
std::move(inputs));
};
}
void CustomKernel::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
@ -39,9 +353,23 @@ void CustomKernel::eval_gpu(
}
auto& d = metal::device(s.device);
const auto& lib_name = name_;
auto lib =
d.get_library(lib_name, [this] { return metal::utils() + source_; });
{
// Clear kernels from the device library cache if needed
auto& kernel_cache = cache();
if (auto it = kernel_cache.libraries.find(name_);
it != kernel_cache.libraries.end()) {
if (it->second != source_) {
auto& d = metal::device(s.device);
d.clear_library(name_);
it->second = source_;
}
} else {
kernel_cache.libraries.emplace(name_, source_);
}
}
auto lib = d.get_library(name_, [this] { return metal::utils() + source_; });
auto kernel = d.get_kernel(name_, lib);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);

View File

@ -295,7 +295,7 @@ void CommandEncoder::barrier() {
Device::Device() {
auto pool = new_scoped_memory_pool();
device_ = load_device();
library_map_ = {{"mlx", load_default_library(device_)}};
default_library_ = load_default_library(device_);
arch_ = std::string(device_->architecture()->name()->utf8String());
auto arch = arch_.back();
switch (arch) {
@ -326,11 +326,11 @@ Device::Device() {
Device::~Device() {
auto pool = new_scoped_memory_pool();
for (auto& k : kernel_map_) {
k.second->release();
}
for (auto& l : library_map_) {
l.second->release();
for (auto& [l, kernel_map] : library_kernels_) {
l->release();
for (auto& [_, k] : kernel_map) {
k->release();
}
}
stream_map_.clear();
device_->release();
@ -474,13 +474,24 @@ CommandEncoder& Device::get_command_encoder(int index) {
return *stream.encoder;
}
void Device::register_library(
const std::string& lib_name,
const std::string& lib_path) {
if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
auto new_lib = load_library(device_, lib_name, lib_path.c_str());
library_map_.insert({lib_name, new_lib});
MTL::Library* Device::get_library(
const std::string& name,
const std::string& path /* = "" */) {
{
std::shared_lock rlock(library_mtx_);
if (auto it = library_map_.find(name); it != library_map_.end()) {
return it->second;
}
}
std::unique_lock wlock(library_mtx_);
if (auto it = library_map_.find(name); it != library_map_.end()) {
return it->second;
}
auto new_lib = load_library(device_, name, path.c_str());
library_map_.insert({name, new_lib});
return new_lib;
}
MTL::Library* Device::build_library_(const std::string& source_string) {
@ -649,6 +660,19 @@ MTL::Library* Device::get_library(
return mtl_lib;
}
void Device::clear_library(const std::string& name) {
std::unique_lock wlock(library_mtx_);
if (auto it = library_map_.find(name); it != library_map_.end()) {
auto kernel_map_it = library_kernels_.find(it->second);
for (auto& [_, kernel] : kernel_map_it->second) {
kernel->release();
}
library_kernels_.erase(kernel_map_it);
it->second->release();
library_map_.erase(it);
}
}
MTL::LinkedFunctions* Device::get_linked_functions_(
const std::vector<MTL::Function*>& funcs) {
if (funcs.empty()) {
@ -679,6 +703,7 @@ MTL::ComputePipelineState* Device::get_kernel_(
std::unique_lock wlock(kernel_mtx_);
// Try loading again to avoid loading twice
auto& kernel_map_ = library_kernels_[mtl_lib];
if (auto it = kernel_map_.find(hash_name); it != kernel_map_.end()) {
return it->second;
}
@ -713,6 +738,7 @@ MTL::ComputePipelineState* Device::get_kernel(
std::shared_lock lock(kernel_mtx_);
// Look for cached kernel
auto& kernel_map_ = library_kernels_[mtl_lib];
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
return it->second;
}
@ -722,23 +748,11 @@ MTL::ComputePipelineState* Device::get_kernel(
MTL::ComputePipelineState* Device::get_kernel(
const std::string& base_name,
const std::string& lib_name /* = "mlx" */,
const std::string& hash_name /* = "" */,
const MTLFCList& func_consts /* = {} */,
const std::vector<MTL::Function*>& linked_functions /* = {} */) {
const auto& kname = hash_name.size() == 0 ? base_name : hash_name;
{
// Multiple readers allowed
std::shared_lock lock(kernel_mtx_);
// Look for cached kernel
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
return it->second;
}
}
// Search for cached metal lib
MTL::Library* mtl_lib = get_library_(lib_name);
return get_kernel_(base_name, mtl_lib, kname, func_consts, linked_functions);
return get_kernel(
base_name, default_library_, hash_name, func_consts, linked_functions);
}
void Device::set_residency_set(const MTL::ResidencySet* residency_set) {

View File

@ -187,14 +187,16 @@ class Device {
CommandEncoder& get_command_encoder(int index);
void end_encoding(int index);
void register_library(
const std::string& lib_name,
const std::string& lib_path = "");
MTL::Library* get_library(
const std::string& name,
const std::string& path = "");
MTL::Library* get_library(
const std::string& name,
const std::function<std::string(void)>& builder);
void clear_library(const std::string& name);
MTL::ComputePipelineState* get_kernel(
const std::string& base_name,
MTL::Library* mtl_lib,
@ -204,7 +206,6 @@ class Device {
MTL::ComputePipelineState* get_kernel(
const std::string& base_name,
const std::string& lib_name = "mlx",
const std::string& hash_name = "",
const MTLFCList& func_consts = {},
const std::vector<MTL::Function*>& linked_functions = {});
@ -258,10 +259,13 @@ class Device {
std::unordered_map<int32_t, DeviceStream> stream_map_;
std::shared_mutex kernel_mtx_;
std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
std::shared_mutex library_mtx_;
std::unordered_map<std::string, MTL::Library*> library_map_;
MTL::Library* default_library_;
std::unordered_map<
MTL::Library*,
std::unordered_map<std::string, MTL::ComputePipelineState*>>
library_kernels_;
const MTL::ResidencySet* residency_set_{nullptr};
std::string arch_;
int max_ops_per_buffer_;

View File

@ -146,7 +146,7 @@ MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
int,
int,
int) {
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
return d.get_kernel(kernel_name, hash_name, func_consts);
}
MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
@ -207,7 +207,7 @@ MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
int,
int,
bool) {
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
return d.get_kernel(kernel_name, hash_name, func_consts);
}
MTL::ComputePipelineState* get_gemv_masked_kernel(
@ -259,7 +259,7 @@ MTL::ComputePipelineState* get_fft_kernel(
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const std::string&) {
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
return d.get_kernel(kernel_name, hash_name, func_consts);
}
MTL::ComputePipelineState* get_quantized_kernel(
@ -283,7 +283,7 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
int,
int,
bool) {
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
return d.get_kernel(kernel_name, hash_name, func_consts);
}
} // namespace mlx::core

View File

@ -172,7 +172,7 @@ void RMSNormVJP::eval_gpu(
auto& compute_encoder = d.get_command_encoder(s.index);
{
auto kernel = d.get_kernel(op_name, "mlx", hash_name, func_consts);
auto kernel = d.get_kernel(op_name, hash_name, func_consts);
MTL::Size grid_dims, group_dims;
if (axis_size <= looped_limit) {
@ -387,7 +387,7 @@ void LayerNormVJP::eval_gpu(
};
{
auto kernel = d.get_kernel(op_name, "mlx", hash_name, func_consts);
auto kernel = d.get_kernel(op_name, hash_name, func_consts);
MTL::Size grid_dims, group_dims;
if (axis_size <= looped_limit) {

View File

@ -73,7 +73,7 @@ void sdpa_full_self_attention_metal(
std::string hash_name = kname.str();
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
auto kernel = d.get_kernel(base_name, hash_name, func_consts);
compute_encoder.set_compute_pipeline_state(kernel);
const int NQ = (qL + bq - 1) / bq;
@ -180,7 +180,7 @@ void sdpa_vector(
// Get the kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname, "mlx", hash_name, func_consts);
auto kernel = d.get_kernel(kname, hash_name, func_consts);
compute_encoder.set_compute_pipeline_state(kernel);
// Set its arguments
@ -281,7 +281,7 @@ void sdpa_vector_2pass(
// Get the kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname, "mlx", hash_name, func_consts);
auto kernel = d.get_kernel(kname, hash_name, func_consts);
compute_encoder.set_compute_pipeline_state(kernel);

View File

@ -2,6 +2,7 @@
#include "mlx/primitives.h"
#include "mlx/distributed/primitives.h"
#include "mlx/fast.h"
#include "mlx/fast_primitives.h"
#define NO_GPU_MULTI(func) \
@ -155,6 +156,18 @@ NO_GPU_USE_FALLBACK(RoPE)
NO_GPU(ScaledDotProductAttention)
NO_GPU_MULTI(AffineQuantize)
NO_GPU_MULTI(CustomKernel)
MetalKernelFunction metal_kernel(
const std::string&,
const std::vector<std::string>&,
const std::vector<std::string>&,
const std::string&,
const std::string&,
bool ensure_row_contiguous,
bool atomic_outputs) {
throw std::runtime_error("[metal_kernel] No GPU back-end.");
}
} // namespace fast
namespace distributed {

View File

@ -1,11 +1,7 @@
// Copyright © 2023-2024 Apple Inc.
#include <cassert>
#include <chrono>
#include <iostream>
#include <numeric>
#include <regex>
#include "mlx/backend/common/compiled.h"
#include "mlx/fast.h"
#include "mlx/fast_primitives.h"
#include "mlx/ops.h"
@ -1030,308 +1026,4 @@ std::vector<Shape> AffineQuantize::output_shapes(
}
}
std::string write_signature(
std::string func_name,
const std::string& header,
const std::string& source,
const std::vector<std::string>& input_names,
const std::vector<array>& inputs,
const std::vector<std::string>& output_names,
const std::vector<Dtype>& output_dtypes,
const std::vector<std::pair<std::string, TemplateArg>>& template_args,
const std::vector<std::string>& attributes,
const std::vector<CustomKernelShapeInfo>& shape_infos,
bool atomic_outputs) {
std::string kernel_source;
kernel_source.reserve(header.size() + source.size() + 16384);
kernel_source += header;
// Auto-generate a function signature based on `template_args`
// and the dtype/shape of the arrays passed as `inputs`.
if (!template_args.empty()) {
kernel_source += "template <";
int i = 0;
for (const auto& [name, arg] : template_args) {
std::string param_type;
if (std::holds_alternative<int>(arg)) {
param_type = "int";
} else if (std::holds_alternative<bool>(arg)) {
param_type = "bool";
} else if (std::holds_alternative<Dtype>(arg)) {
param_type = "typename";
}
if (i > 0) {
kernel_source += ", ";
}
kernel_source += param_type;
kernel_source += " ";
kernel_source += name;
i++;
}
kernel_source += ">\n";
}
kernel_source += "[[kernel]] void ";
kernel_source += func_name;
kernel_source += "(\n";
int index = 0;
constexpr int max_constant_array_size = 8;
// Add inputs
for (int i = 0; i < inputs.size(); ++i) {
const auto& name = input_names[i];
const auto& arr = inputs[i];
auto dtype = get_type_string(arr.dtype());
std::string location =
arr.size() < max_constant_array_size ? "constant" : "device";
std::string ref = arr.ndim() == 0 ? "&" : "*";
kernel_source += " const ";
kernel_source += location;
kernel_source += " ";
kernel_source += dtype;
kernel_source += ref;
kernel_source += " ";
kernel_source += name;
kernel_source += " [[buffer(";
kernel_source += std::to_string(index);
kernel_source += ")]],\n";
index++;
// Add input shape, strides and ndim if present in the source
if (arr.ndim() > 0) {
if (shape_infos[i].shape) {
kernel_source +=
(" const constant int* " + name + "_shape [[buffer(" +
std::to_string(index) + ")]],\n");
index++;
}
if (shape_infos[i].strides) {
kernel_source +=
(" const constant int64_t* " + name + "_strides [[buffer(" +
std::to_string(index) + ")]],\n");
index++;
}
if (shape_infos[i].ndim) {
kernel_source +=
(" const constant int& " + name + "_ndim [[buffer(" +
std::to_string(index) + ")]],\n");
index++;
}
}
}
// Add outputs
for (int i = 0; i < output_names.size(); ++i) {
const auto& name = output_names[i];
const auto& dtype = output_dtypes[i];
kernel_source += " device ";
auto type_string = get_type_string(dtype);
if (atomic_outputs) {
kernel_source += "atomic<";
}
kernel_source += type_string;
if (atomic_outputs) {
kernel_source += ">";
}
kernel_source += "* ";
kernel_source += name;
kernel_source += " [[buffer(";
kernel_source += std::to_string(index);
kernel_source += ")]]";
if (index < inputs.size() + output_names.size() - 1 ||
attributes.size() > 0) {
kernel_source += ",\n";
} else {
kernel_source += ") {\n";
}
index++;
}
index = 0;
for (const auto& attr : attributes) {
kernel_source += attr;
if (index < attributes.size() - 1) {
kernel_source += ",\n";
} else {
kernel_source += ") {\n";
}
index++;
}
kernel_source += source;
kernel_source += "\n}\n";
return kernel_source;
}
std::string write_template(
const std::vector<std::pair<std::string, TemplateArg>>& template_args) {
std::ostringstream template_def;
template_def << "<";
int i = 0;
for (const auto& [name, arg] : template_args) {
if (i > 0) {
template_def << ", ";
}
if (std::holds_alternative<int>(arg)) {
template_def << std::get<int>(arg);
} else if (std::holds_alternative<bool>(arg)) {
template_def << std::get<bool>(arg);
} else if (std::holds_alternative<Dtype>(arg)) {
template_def << get_type_string(std::get<Dtype>(arg));
}
i++;
}
template_def << ">";
return template_def.str();
}
MetalKernelFunction metal_kernel(
const std::string& name,
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
const std::string& source,
const std::string& header /* = "" */,
bool ensure_row_contiguous /* = true */,
bool atomic_outputs /* = false */) {
if (output_names.empty()) {
throw std::invalid_argument(
"[metal_kernel] Must specify at least one output.");
}
std::vector<CustomKernelShapeInfo> shape_infos;
for (auto& n : input_names) {
CustomKernelShapeInfo shape_info;
shape_info.shape = source.find(n + "_shape") != std::string::npos;
shape_info.strides = source.find(n + "_strides") != std::string::npos;
shape_info.ndim = source.find(n + "_ndim") != std::string::npos;
shape_infos.push_back(shape_info);
}
const std::vector<std::pair<std::string, std::string>> metal_attributes = {
{"dispatch_quadgroups_per_threadgroup", "uint"},
{"dispatch_simdgroups_per_threadgroup", "uint"},
{"dispatch_threads_per_threadgroup", "uint3"},
{"grid_origin", "uint3"},
{"grid_size", "uint3"},
{"quadgroup_index_in_threadgroup", "uint"},
{"quadgroups_per_threadgroup", "uint"},
{"simdgroup_index_in_threadgroup", "uint"},
{"simdgroups_per_threadgroup", "uint"},
{"thread_execution_width", "uint"},
{"thread_index_in_quadgroup", "uint"},
{"thread_index_in_simdgroup", "uint"},
{"thread_index_in_threadgroup", "uint"},
{"thread_position_in_grid", "uint3"},
{"thread_position_in_threadgroup", "uint3"},
{"threadgroup_position_in_grid", "uint3"},
{"threadgroups_per_grid", "uint3"},
{"threads_per_grid", "uint3"},
{"threads_per_simdgroup", "uint"},
{"threads_per_threadgroup", "uint3"},
};
std::vector<std::string> attributes;
for (const auto& [attr, dtype] : metal_attributes) {
if (source.find(attr) != std::string::npos) {
attributes.push_back(" " + dtype + " " + attr + " [[" + attr + "]]");
}
}
auto now = std::chrono::system_clock::now();
int64_t timestamp = std::chrono::duration_cast<std::chrono::milliseconds>(
now.time_since_epoch())
.count();
return [=,
shape_infos = std::move(shape_infos),
attributes = std::move(attributes)](
const std::vector<array>& inputs,
const std::vector<Shape>& output_shapes,
const std::vector<Dtype>& output_dtypes,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
const std::vector<std::pair<std::string, TemplateArg>>&
template_args = {},
std::optional<float> init_value = std::nullopt,
bool verbose = false,
StreamOrDevice s_ = {}) {
if (inputs.size() != input_names.size()) {
std::ostringstream msg;
msg << "[metal_kernel] Expected `inputs` to have size "
<< input_names.size() << " but got size " << inputs.size() << "."
<< std::endl;
throw std::invalid_argument(msg.str());
}
if (output_shapes.size() != output_names.size()) {
std::ostringstream msg;
msg << "[metal_kernel] Expected `output_shapes` to have size "
<< output_names.size() << " but got size " << output_shapes.size()
<< "." << std::endl;
throw std::invalid_argument(msg.str());
}
if (output_dtypes.size() != output_names.size()) {
std::ostringstream msg;
msg << "[metal_kernel] Expected `output_dtypes` to have size "
<< output_names.size() << " but got size " << output_dtypes.size()
<< "." << std::endl;
throw std::invalid_argument(msg.str());
}
auto s = to_stream(s_);
if (s.device != Device::gpu) {
throw std::invalid_argument("[metal_kernel] Only supports the GPU.");
}
std::ostringstream func_name;
std::string template_def = "";
std::string template_hash = "";
if (!template_args.empty()) {
std::regex disallowed_chars("\\<|\\>|(, )");
template_def = write_template(template_args);
template_hash = std::regex_replace(template_def, disallowed_chars, "_");
template_hash.pop_back();
}
func_name << "custom_kernel_" << name << "_" << template_hash << "_"
<< timestamp;
std::string kernel_name = func_name.str();
std::string kernel_source = write_signature(
kernel_name,
header,
source,
input_names,
inputs,
output_names,
output_dtypes,
template_args,
attributes,
shape_infos,
atomic_outputs);
if (!template_args.empty()) {
template_def = kernel_name + template_def;
kernel_source += "\ntemplate [[host_name(\"";
kernel_source += kernel_name;
kernel_source += "\")]] [[kernel]] decltype(";
kernel_source += template_def;
kernel_source += ") ";
kernel_source += template_def;
kernel_source += ";\n";
}
if (verbose) {
std::cout << "Generated source code for `" << name << "`:" << std::endl
<< "```" << std::endl
<< kernel_source << std::endl
<< "```" << std::endl;
}
return array::make_arrays(
std::move(output_shapes),
std::move(output_dtypes),
std::make_shared<CustomKernel>(
s,
std::move(kernel_name),
std::move(kernel_source),
grid,
threadgroup,
shape_infos,
ensure_row_contiguous,
init_value),
std::move(inputs));
};
}
} // namespace mlx::core::fast