diff --git a/docs/src/dev/extensions.rst b/docs/src/dev/extensions.rst
index 2aef28f99..03f1c2163 100644
--- a/docs/src/dev/extensions.rst
+++ b/docs/src/dev/extensions.rst
@@ -397,11 +397,11 @@ below.
       std::ostringstream kname;
       kname << "axpby_" << "general_" << type_to_name(out);
 
-      // Make sure the metal library is available
-      d.register_library("mlx_ext");
+      // Load the metal library
+      auto lib = d.get_library("mlx_ext");
 
       // Make a kernel from this metal library
-      auto kernel = d.get_kernel(kname.str(), "mlx_ext");
+      auto kernel = d.get_kernel(kname.str(), lib);
 
       // Prepare to encode kernel
       auto& compute_encoder = d.get_command_encoder(s.index);
diff --git a/examples/extensions/axpby/axpby.cpp b/examples/extensions/axpby/axpby.cpp
index 291246617..9ba933483 100644
--- a/examples/extensions/axpby/axpby.cpp
+++ b/examples/extensions/axpby/axpby.cpp
@@ -172,11 +172,11 @@ void Axpby::eval_gpu(
   kname << (contiguous_kernel ? "contiguous_" : "general_");
   kname << type_to_name(out);
 
-  // Make sure the metal library is available
-  d.register_library("mlx_ext");
+  // Load the metal library
+  auto lib = d.get_library("mlx_ext");
 
   // Make a kernel from this metal library
-  auto kernel = d.get_kernel(kname.str(), "mlx_ext");
+  auto kernel = d.get_kernel(kname.str(), lib);
 
   // Prepare to encode kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp
index 6b4b70d47..593b79384 100644
--- a/mlx/backend/metal/conv.cpp
+++ b/mlx/backend/metal/conv.cpp
@@ -677,7 +677,7 @@ void depthwise_conv_2D_gpu(
   std::string hash_name = kname.str();
 
   auto& compute_encoder = d.get_command_encoder(s.index);
-  auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
+  auto kernel = d.get_kernel(base_name, hash_name, func_consts);
   compute_encoder.set_compute_pipeline_state(kernel);
 
   compute_encoder.set_input_array(in, 0);
diff --git a/mlx/backend/metal/custom_kernel.cpp b/mlx/backend/metal/custom_kernel.cpp
index 0240126b1..161503a0e 100644
--- a/mlx/backend/metal/custom_kernel.cpp
+++ b/mlx/backend/metal/custom_kernel.cpp
@@ -1,12 +1,326 @@
 // Copyright © 2024 Apple Inc.
 
+#include <iostream>
+#include <regex>
+
+#include "mlx/backend/common/compiled.h"
 #include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/metal/jit/includes.h"
 #include "mlx/backend/metal/utils.h"
+#include "mlx/fast.h"
 #include "mlx/fast_primitives.h"
+#include "mlx/utils.h"
 
 namespace mlx::core::fast {
 
+struct CustomKernelCache {
+  std::unordered_map<std::string, std::string> libraries;
+};
+
+static CustomKernelCache& cache() {
+  static CustomKernelCache cache_;
+  return cache_;
+};
+
+std::string write_signature(
+    std::string func_name,
+    const std::string& header,
+    const std::string& source,
+    const std::vector<std::string>& input_names,
+    const std::vector<array>& inputs,
+    const std::vector<std::string>& output_names,
+    const std::vector<Dtype>& output_dtypes,
+    const std::vector<std::pair<std::string, TemplateArg>>& template_args,
+    const std::vector<std::string>& attributes,
+    const std::vector<CustomKernelShapeInfo>& shape_infos,
+    bool atomic_outputs) {
+  std::string kernel_source;
+  kernel_source.reserve(header.size() + source.size() + 16384);
+  kernel_source += header;
+  // Auto-generate a function signature based on `template_args`
+  // and the dtype/shape of the arrays passed as `inputs`.
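+  //
+  // For illustration only (the kernel and argument names here are
+  // hypothetical): a kernel created with name "myexp", one float32 input
+  // `inp` (large enough to be bound in device memory) and one output `out`,
+  // whose source reads `thread_position_in_grid`, gets a signature like:
+  //
+  //   [[kernel]] void custom_kernel_myexp(
+  //       const device float* inp [[buffer(0)]],
+  //       device float* out [[buffer(1)]],
+  //       uint3 thread_position_in_grid [[thread_position_in_grid]]) {
+  //     ...user source...
+  //   }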
+  if (!template_args.empty()) {
+    kernel_source += "template <";
+    int i = 0;
+    for (const auto& [name, arg] : template_args) {
+      std::string param_type;
+      if (std::holds_alternative<int>(arg)) {
+        param_type = "int";
+      } else if (std::holds_alternative<bool>(arg)) {
+        param_type = "bool";
+      } else if (std::holds_alternative<Dtype>(arg)) {
+        param_type = "typename";
+      }
+      if (i > 0) {
+        kernel_source += ", ";
+      }
+      kernel_source += param_type;
+      kernel_source += " ";
+      kernel_source += name;
+      i++;
+    }
+    kernel_source += ">\n";
+  }
+  kernel_source += "[[kernel]] void ";
+  kernel_source += func_name;
+  kernel_source += "(\n";
+
+  int index = 0;
+  constexpr int max_constant_array_size = 8;
+  // Add inputs
+  for (int i = 0; i < inputs.size(); ++i) {
+    const auto& name = input_names[i];
+    const auto& arr = inputs[i];
+    auto dtype = get_type_string(arr.dtype());
+    std::string location =
+        arr.size() < max_constant_array_size ? "constant" : "device";
+    std::string ref = arr.ndim() == 0 ? "&" : "*";
+    kernel_source += "  const ";
+    kernel_source += location;
+    kernel_source += " ";
+    kernel_source += dtype;
+    kernel_source += ref;
+    kernel_source += " ";
+    kernel_source += name;
+    kernel_source += " [[buffer(";
+    kernel_source += std::to_string(index);
+    kernel_source += ")]],\n";
+    index++;
+    // Add input shape, strides and ndim if present in the source
+    if (arr.ndim() > 0) {
+      if (shape_infos[i].shape) {
+        kernel_source +=
+            ("  const constant int* " + name + "_shape [[buffer(" +
+             std::to_string(index) + ")]],\n");
+        index++;
+      }
+      if (shape_infos[i].strides) {
+        kernel_source +=
+            ("  const constant int64_t* " + name + "_strides [[buffer(" +
+             std::to_string(index) + ")]],\n");
+        index++;
+      }
+      if (shape_infos[i].ndim) {
+        kernel_source +=
+            ("  const constant int& " + name + "_ndim [[buffer(" +
+             std::to_string(index) + ")]],\n");
+        index++;
+      }
+    }
+  }
+  // Add outputs
+  for (int i = 0; i < output_names.size(); ++i) {
+    const auto& name = output_names[i];
+    const auto& dtype = output_dtypes[i];
+    kernel_source += "  device ";
+    auto type_string = get_type_string(dtype);
+    if (atomic_outputs) {
+      kernel_source += "atomic<";
+    }
+    kernel_source += type_string;
+    if (atomic_outputs) {
+      kernel_source += ">";
+    }
+    kernel_source += "* ";
+    kernel_source += name;
+    kernel_source += " [[buffer(";
+    kernel_source += std::to_string(index);
+    kernel_source += ")]]";
+    if (index < inputs.size() + output_names.size() - 1 ||
+        attributes.size() > 0) {
+      kernel_source += ",\n";
+    } else {
+      kernel_source += ") {\n";
+    }
+    index++;
+  }
+
+  index = 0;
+  for (const auto& attr : attributes) {
+    kernel_source += attr;
+    if (index < attributes.size() - 1) {
+      kernel_source += ",\n";
+    } else {
+      kernel_source += ") {\n";
+    }
+    index++;
+  }
+  kernel_source += source;
+  kernel_source += "\n}\n";
+  return kernel_source;
+}
+
+std::string write_template(
+    const std::vector<std::pair<std::string, TemplateArg>>& template_args) {
+  std::ostringstream template_def;
+  template_def << "<";
+  int i = 0;
+  for (const auto& [name, arg] : template_args) {
+    if (i > 0) {
+      template_def << ", ";
+    }
+    if (std::holds_alternative<int>(arg)) {
+      template_def << std::get<int>(arg);
+    } else if (std::holds_alternative<bool>(arg)) {
+      template_def << std::get<bool>(arg);
+    } else if (std::holds_alternative<Dtype>(arg)) {
+      template_def << get_type_string(std::get<Dtype>(arg));
+    }
+    i++;
+  }
+  template_def << ">";
+  return template_def.str();
+}
+
+MetalKernelFunction metal_kernel(
+    const std::string& name,
+    const std::vector<std::string>& input_names,
+    const std::vector<std::string>& output_names,
+    const std::string& source,
+    const std::string& header /* = "" */,
std::string& header /* = "" */, + bool ensure_row_contiguous /* = true */, + bool atomic_outputs /* = false */) { + if (output_names.empty()) { + throw std::invalid_argument( + "[metal_kernel] Must specify at least one output."); + } + std::vector shape_infos; + for (auto& n : input_names) { + CustomKernelShapeInfo shape_info; + shape_info.shape = source.find(n + "_shape") != std::string::npos; + shape_info.strides = source.find(n + "_strides") != std::string::npos; + shape_info.ndim = source.find(n + "_ndim") != std::string::npos; + shape_infos.push_back(shape_info); + } + const std::vector> metal_attributes = { + {"dispatch_quadgroups_per_threadgroup", "uint"}, + {"dispatch_simdgroups_per_threadgroup", "uint"}, + {"dispatch_threads_per_threadgroup", "uint3"}, + {"grid_origin", "uint3"}, + {"grid_size", "uint3"}, + {"quadgroup_index_in_threadgroup", "uint"}, + {"quadgroups_per_threadgroup", "uint"}, + {"simdgroup_index_in_threadgroup", "uint"}, + {"simdgroups_per_threadgroup", "uint"}, + {"thread_execution_width", "uint"}, + {"thread_index_in_quadgroup", "uint"}, + {"thread_index_in_simdgroup", "uint"}, + {"thread_index_in_threadgroup", "uint"}, + {"thread_position_in_grid", "uint3"}, + {"thread_position_in_threadgroup", "uint3"}, + {"threadgroup_position_in_grid", "uint3"}, + {"threadgroups_per_grid", "uint3"}, + {"threads_per_grid", "uint3"}, + {"threads_per_simdgroup", "uint"}, + {"threads_per_threadgroup", "uint3"}, + }; + + std::vector attributes; + for (const auto& [attr, dtype] : metal_attributes) { + if (source.find(attr) != std::string::npos) { + attributes.push_back(" " + dtype + " " + attr + " [[" + attr + "]]"); + } + } + + return [=, + shape_infos = std::move(shape_infos), + attributes = std::move(attributes)]( + const std::vector& inputs, + const std::vector& output_shapes, + const std::vector& output_dtypes, + std::tuple grid, + std::tuple threadgroup, + const std::vector>& + template_args = {}, + std::optional init_value = std::nullopt, + bool verbose = false, + StreamOrDevice s_ = {}) { + if (inputs.size() != input_names.size()) { + std::ostringstream msg; + msg << "[metal_kernel] Expected `inputs` to have size " + << input_names.size() << " but got size " << inputs.size() << "." + << std::endl; + throw std::invalid_argument(msg.str()); + } + if (output_shapes.size() != output_names.size()) { + std::ostringstream msg; + msg << "[metal_kernel] Expected `output_shapes` to have size " + << output_names.size() << " but got size " << output_shapes.size() + << "." << std::endl; + throw std::invalid_argument(msg.str()); + } + if (output_dtypes.size() != output_names.size()) { + std::ostringstream msg; + msg << "[metal_kernel] Expected `output_dtypes` to have size " + << output_names.size() << " but got size " << output_dtypes.size() + << "." 
+      throw std::invalid_argument(msg.str());
+    }
+
+    auto s = to_stream(s_);
+    if (s.device != Device::gpu) {
+      throw std::invalid_argument("[metal_kernel] Only supports the GPU.");
+    }
+
+    std::string kernel_name = "custom_kernel_" + name;
+    std::string template_def = "";
+    if (!template_args.empty()) {
+      std::regex disallowed_chars("\\<|\\>|(, )");
+      template_def = write_template(template_args);
+      auto template_hash =
+          std::regex_replace(template_def, disallowed_chars, "_");
+      template_hash.pop_back();
+      kernel_name += "_";
+      kernel_name += template_hash;
+    }
+
+    std::string kernel_source = write_signature(
+        kernel_name,
+        header,
+        source,
+        input_names,
+        inputs,
+        output_names,
+        output_dtypes,
+        template_args,
+        attributes,
+        shape_infos,
+        atomic_outputs);
+
+    if (!template_args.empty()) {
+      template_def = kernel_name + template_def;
+      kernel_source += "\ntemplate [[host_name(\"";
+      kernel_source += kernel_name;
+      kernel_source += "\")]] [[kernel]] decltype(";
+      kernel_source += template_def;
+      kernel_source += ") ";
+      kernel_source += template_def;
+      kernel_source += ";\n";
+    }
+
+    if (verbose) {
+      std::cout << "Generated source code for `" << name << "`:" << std::endl
+                << "```" << std::endl
+                << kernel_source << std::endl
+                << "```" << std::endl;
+    }
+
+    return array::make_arrays(
+        std::move(output_shapes),
+        std::move(output_dtypes),
+        std::make_shared<CustomKernel>(
+            s,
+            std::move(kernel_name),
+            std::move(kernel_source),
+            grid,
+            threadgroup,
+            shape_infos,
+            ensure_row_contiguous,
+            init_value),
+        std::move(inputs));
+  };
+}
+
 void CustomKernel::eval_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
@@ -39,9 +353,23 @@ void CustomKernel::eval_gpu(
   }
 
   auto& d = metal::device(s.device);
-  const auto& lib_name = name_;
-  auto lib =
-      d.get_library(lib_name, [this] { return metal::utils() + source_; });
+
+  {
+    // Clear kernels from the device library cache if needed
+    auto& kernel_cache = cache();
+    if (auto it = kernel_cache.libraries.find(name_);
+        it != kernel_cache.libraries.end()) {
+      if (it->second != source_) {
+        auto& d = metal::device(s.device);
+        d.clear_library(name_);
+        it->second = source_;
+      }
+    } else {
+      kernel_cache.libraries.emplace(name_, source_);
+    }
+  }
+
+  auto lib = d.get_library(name_, [this] { return metal::utils() + source_; });
   auto kernel = d.get_kernel(name_, lib);
   auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder.set_compute_pipeline_state(kernel);
diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp
index ebc3cc77f..425274361 100644
--- a/mlx/backend/metal/device.cpp
+++ b/mlx/backend/metal/device.cpp
@@ -295,7 +295,7 @@ void CommandEncoder::barrier() {
 Device::Device() {
   auto pool = new_scoped_memory_pool();
   device_ = load_device();
-  library_map_ = {{"mlx", load_default_library(device_)}};
+  default_library_ = load_default_library(device_);
   arch_ = std::string(device_->architecture()->name()->utf8String());
   auto arch = arch_.back();
   switch (arch) {
@@ -326,11 +326,11 @@ Device::Device() {
 
 Device::~Device() {
   auto pool = new_scoped_memory_pool();
-  for (auto& k : kernel_map_) {
-    k.second->release();
-  }
-  for (auto& l : library_map_) {
-    l.second->release();
+  for (auto& [l, kernel_map] : library_kernels_) {
+    l->release();
+    for (auto& [_, k] : kernel_map) {
+      k->release();
+    }
   }
   stream_map_.clear();
   device_->release();
@@ -474,13 +474,24 @@ CommandEncoder& Device::get_command_encoder(int index) {
   return *stream.encoder;
 }
 
-void Device::register_library(
-    const std::string& lib_name,
-    const std::string& lib_path) {
-  if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
-    auto new_lib = load_library(device_, lib_name, lib_path.c_str());
-    library_map_.insert({lib_name, new_lib});
+MTL::Library* Device::get_library(
+    const std::string& name,
+    const std::string& path /* = "" */) {
+  {
+    std::shared_lock rlock(library_mtx_);
+    if (auto it = library_map_.find(name); it != library_map_.end()) {
+      return it->second;
+    }
   }
+
+  std::unique_lock wlock(library_mtx_);
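+  // Check the map again: another thread may have loaded the library
+  // while this one was waiting to acquire the write lock.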
+  if (auto it = library_map_.find(name); it != library_map_.end()) {
+    return it->second;
+  }
+
+  auto new_lib = load_library(device_, name, path.c_str());
+  library_map_.insert({name, new_lib});
+  return new_lib;
 }
 
 MTL::Library* Device::build_library_(const std::string& source_string) {
@@ -649,6 +660,19 @@ MTL::Library* Device::get_library(
   return mtl_lib;
 }
 
+void Device::clear_library(const std::string& name) {
+  std::unique_lock wlock(library_mtx_);
+  if (auto it = library_map_.find(name); it != library_map_.end()) {
+    auto kernel_map_it = library_kernels_.find(it->second);
+    for (auto& [_, kernel] : kernel_map_it->second) {
+      kernel->release();
+    }
+    library_kernels_.erase(kernel_map_it);
+    it->second->release();
+    library_map_.erase(it);
+  }
+}
+
 MTL::LinkedFunctions* Device::get_linked_functions_(
     const std::vector<MTL::Function*>& funcs) {
   if (funcs.empty()) {
@@ -679,6 +703,7 @@ MTL::ComputePipelineState* Device::get_kernel_(
   std::unique_lock wlock(kernel_mtx_);
 
   // Try loading again to avoid loading twice
+  auto& kernel_map_ = library_kernels_[mtl_lib];
   if (auto it = kernel_map_.find(hash_name); it != kernel_map_.end()) {
     return it->second;
   }
@@ -713,6 +738,7 @@ MTL::ComputePipelineState* Device::get_kernel(
   std::shared_lock lock(kernel_mtx_);
 
   // Look for cached kernel
+  auto& kernel_map_ = library_kernels_[mtl_lib];
   if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
     return it->second;
   }
@@ -722,23 +748,11 @@ MTL::ComputePipelineState* Device::get_kernel(
 
 MTL::ComputePipelineState* Device::get_kernel(
     const std::string& base_name,
-    const std::string& lib_name /* = "mlx" */,
    const std::string& hash_name /* = "" */,
     const MTLFCList& func_consts /* = {} */,
     const std::vector<MTL::Function*>& linked_functions /* = {} */) {
-  const auto& kname = hash_name.size() == 0 ? base_name : hash_name;
-  {
-    // Multiple readers allowed
-    std::shared_lock lock(kernel_mtx_);
-
-    // Look for cached kernel
-    if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
-      return it->second;
-    }
-  }
-
-  // Search for cached metal lib
-  MTL::Library* mtl_lib = get_library_(lib_name);
-  return get_kernel_(base_name, mtl_lib, kname, func_consts, linked_functions);
+  return get_kernel(
+      base_name, default_library_, hash_name, func_consts, linked_functions);
 }
 
 void Device::set_residency_set(const MTL::ResidencySet* residency_set) {
diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h
index 660ba65e2..5bfcc6649 100644
--- a/mlx/backend/metal/device.h
+++ b/mlx/backend/metal/device.h
@@ -187,14 +187,16 @@ class Device {
   CommandEncoder& get_command_encoder(int index);
   void end_encoding(int index);
 
-  void register_library(
-      const std::string& lib_name,
-      const std::string& lib_path = "");
+  MTL::Library* get_library(
+      const std::string& name,
+      const std::string& path = "");
 
   MTL::Library* get_library(
       const std::string& name,
       const std::function<std::string(void)>& builder);
 
+  void clear_library(const std::string& name);
+
   MTL::ComputePipelineState* get_kernel(
       const std::string& base_name,
       MTL::Library* mtl_lib,
@@ -204,7 +206,6 @@ class Device {
 
   MTL::ComputePipelineState* get_kernel(
       const std::string& base_name,
-      const std::string& lib_name = "mlx",
       const std::string& hash_name = "",
       const MTLFCList& func_consts = {},
       const std::vector<MTL::Function*>& linked_functions = {});
@@ -258,10 +259,13 @@ class Device {
   std::unordered_map<int32_t, DeviceStream> stream_map_;
 
   std::shared_mutex kernel_mtx_;
-  std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
-
   std::shared_mutex library_mtx_;
   std::unordered_map<std::string, MTL::Library*> library_map_;
+  MTL::Library* default_library_;
+  std::unordered_map<
+      MTL::Library*,
+      std::unordered_map<std::string, MTL::ComputePipelineState*>>
+      library_kernels_;
 
   const MTL::ResidencySet* residency_set_{nullptr};
   std::string arch_;
   int max_ops_per_buffer_;
diff --git a/mlx/backend/metal/nojit_kernels.cpp b/mlx/backend/metal/nojit_kernels.cpp
index 8da147971..b1478d33b 100644
--- a/mlx/backend/metal/nojit_kernels.cpp
+++ b/mlx/backend/metal/nojit_kernels.cpp
@@ -146,7 +146,7 @@ MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
     int,
     int,
     int) {
-  return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
+  return d.get_kernel(kernel_name, hash_name, func_consts);
 }
 
 MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
@@ -207,7 +207,7 @@ MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
     int,
     int,
     bool) {
-  return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
+  return d.get_kernel(kernel_name, hash_name, func_consts);
 }
 
 MTL::ComputePipelineState* get_gemv_masked_kernel(
@@ -259,7 +259,7 @@ MTL::ComputePipelineState* get_fft_kernel(
     const std::string& hash_name,
     const metal::MTLFCList& func_consts,
     const std::string&) {
-  return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
+  return d.get_kernel(kernel_name, hash_name, func_consts);
 }
 
 MTL::ComputePipelineState* get_quantized_kernel(
@@ -283,7 +283,7 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
     int,
     int,
     bool) {
-  return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
+  return d.get_kernel(kernel_name, hash_name, func_consts);
 }
 
 } // namespace mlx::core
diff --git a/mlx/backend/metal/normalization.cpp b/mlx/backend/metal/normalization.cpp
index c0901ccec..781427824 100644
--- a/mlx/backend/metal/normalization.cpp
+++ b/mlx/backend/metal/normalization.cpp
@@ -172,7 +172,7 @@ void RMSNormVJP::eval_gpu(
 
   auto& compute_encoder = d.get_command_encoder(s.index);
   {
-    auto kernel = d.get_kernel(op_name, "mlx", hash_name, func_consts);
auto kernel = d.get_kernel(op_name, "mlx", hash_name, func_consts); + auto kernel = d.get_kernel(op_name, hash_name, func_consts); MTL::Size grid_dims, group_dims; if (axis_size <= looped_limit) { @@ -387,7 +387,7 @@ void LayerNormVJP::eval_gpu( }; { - auto kernel = d.get_kernel(op_name, "mlx", hash_name, func_consts); + auto kernel = d.get_kernel(op_name, hash_name, func_consts); MTL::Size grid_dims, group_dims; if (axis_size <= looped_limit) { diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index aad1a0018..25763be6d 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -73,7 +73,7 @@ void sdpa_full_self_attention_metal( std::string hash_name = kname.str(); auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts); + auto kernel = d.get_kernel(base_name, hash_name, func_consts); compute_encoder.set_compute_pipeline_state(kernel); const int NQ = (qL + bq - 1) / bq; @@ -180,7 +180,7 @@ void sdpa_vector( // Get the kernel auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname, "mlx", hash_name, func_consts); + auto kernel = d.get_kernel(kname, hash_name, func_consts); compute_encoder.set_compute_pipeline_state(kernel); // Set its arguments @@ -281,7 +281,7 @@ void sdpa_vector_2pass( // Get the kernel auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname, "mlx", hash_name, func_consts); + auto kernel = d.get_kernel(kname, hash_name, func_consts); compute_encoder.set_compute_pipeline_state(kernel); diff --git a/mlx/backend/no_gpu/primitives.cpp b/mlx/backend/no_gpu/primitives.cpp index 409aa2c89..849cbf83e 100644 --- a/mlx/backend/no_gpu/primitives.cpp +++ b/mlx/backend/no_gpu/primitives.cpp @@ -2,6 +2,7 @@ #include "mlx/primitives.h" #include "mlx/distributed/primitives.h" +#include "mlx/fast.h" #include "mlx/fast_primitives.h" #define NO_GPU_MULTI(func) \ @@ -155,6 +156,18 @@ NO_GPU_USE_FALLBACK(RoPE) NO_GPU(ScaledDotProductAttention) NO_GPU_MULTI(AffineQuantize) NO_GPU_MULTI(CustomKernel) + +MetalKernelFunction metal_kernel( + const std::string&, + const std::vector&, + const std::vector&, + const std::string&, + const std::string&, + bool ensure_row_contiguous, + bool atomic_outputs) { + throw std::runtime_error("[metal_kernel] No GPU back-end."); +} + } // namespace fast namespace distributed { diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 7a86f8d18..94fbc260f 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -1,11 +1,7 @@ // Copyright © 2023-2024 Apple Inc. #include -#include -#include #include -#include -#include "mlx/backend/common/compiled.h" #include "mlx/fast.h" #include "mlx/fast_primitives.h" #include "mlx/ops.h" @@ -1030,308 +1026,4 @@ std::vector AffineQuantize::output_shapes( } } -std::string write_signature( - std::string func_name, - const std::string& header, - const std::string& source, - const std::vector& input_names, - const std::vector& inputs, - const std::vector& output_names, - const std::vector& output_dtypes, - const std::vector>& template_args, - const std::vector& attributes, - const std::vector& shape_infos, - bool atomic_outputs) { - std::string kernel_source; - kernel_source.reserve(header.size() + source.size() + 16384); - kernel_source += header; - // Auto-generate a function signature based on `template_args` - // and the dtype/shape of the arrays passed as `inputs`. 
-  if (!template_args.empty()) {
-    kernel_source += "template <";
-    int i = 0;
-    for (const auto& [name, arg] : template_args) {
-      std::string param_type;
-      if (std::holds_alternative<int>(arg)) {
-        param_type = "int";
-      } else if (std::holds_alternative<bool>(arg)) {
-        param_type = "bool";
-      } else if (std::holds_alternative<Dtype>(arg)) {
-        param_type = "typename";
-      }
-      if (i > 0) {
-        kernel_source += ", ";
-      }
-      kernel_source += param_type;
-      kernel_source += " ";
-      kernel_source += name;
-      i++;
-    }
-    kernel_source += ">\n";
-  }
-  kernel_source += "[[kernel]] void ";
-  kernel_source += func_name;
-  kernel_source += "(\n";
-
-  int index = 0;
-  constexpr int max_constant_array_size = 8;
-  // Add inputs
-  for (int i = 0; i < inputs.size(); ++i) {
-    const auto& name = input_names[i];
-    const auto& arr = inputs[i];
-    auto dtype = get_type_string(arr.dtype());
-    std::string location =
-        arr.size() < max_constant_array_size ? "constant" : "device";
-    std::string ref = arr.ndim() == 0 ? "&" : "*";
-    kernel_source += "  const ";
-    kernel_source += location;
-    kernel_source += " ";
-    kernel_source += dtype;
-    kernel_source += ref;
-    kernel_source += " ";
-    kernel_source += name;
-    kernel_source += " [[buffer(";
-    kernel_source += std::to_string(index);
-    kernel_source += ")]],\n";
-    index++;
-    // Add input shape, strides and ndim if present in the source
-    if (arr.ndim() > 0) {
-      if (shape_infos[i].shape) {
-        kernel_source +=
-            ("  const constant int* " + name + "_shape [[buffer(" +
-             std::to_string(index) + ")]],\n");
-        index++;
-      }
-      if (shape_infos[i].strides) {
-        kernel_source +=
-            ("  const constant int64_t* " + name + "_strides [[buffer(" +
-             std::to_string(index) + ")]],\n");
-        index++;
-      }
-      if (shape_infos[i].ndim) {
-        kernel_source +=
-            ("  const constant int& " + name + "_ndim [[buffer(" +
-             std::to_string(index) + ")]],\n");
-        index++;
-      }
-    }
-  }
-  // Add outputs
-  for (int i = 0; i < output_names.size(); ++i) {
-    const auto& name = output_names[i];
-    const auto& dtype = output_dtypes[i];
-    kernel_source += "  device ";
-    auto type_string = get_type_string(dtype);
-    if (atomic_outputs) {
-      kernel_source += "atomic<";
-    }
-    kernel_source += type_string;
-    if (atomic_outputs) {
-      kernel_source += ">";
-    }
-    kernel_source += "* ";
-    kernel_source += name;
-    kernel_source += " [[buffer(";
-    kernel_source += std::to_string(index);
-    kernel_source += ")]]";
-    if (index < inputs.size() + output_names.size() - 1 ||
-        attributes.size() > 0) {
-      kernel_source += ",\n";
-    } else {
-      kernel_source += ") {\n";
-    }
-    index++;
-  }
-
-  index = 0;
-  for (const auto& attr : attributes) {
-    kernel_source += attr;
-    if (index < attributes.size() - 1) {
-      kernel_source += ",\n";
-    } else {
-      kernel_source += ") {\n";
-    }
-    index++;
-  }
-  kernel_source += source;
-  kernel_source += "\n}\n";
-  return kernel_source;
-}
-
-std::string write_template(
-    const std::vector<std::pair<std::string, TemplateArg>>& template_args) {
-  std::ostringstream template_def;
-  template_def << "<";
-  int i = 0;
-  for (const auto& [name, arg] : template_args) {
-    if (i > 0) {
-      template_def << ", ";
-    }
-    if (std::holds_alternative<int>(arg)) {
-      template_def << std::get<int>(arg);
-    } else if (std::holds_alternative<bool>(arg)) {
-      template_def << std::get<bool>(arg);
-    } else if (std::holds_alternative<Dtype>(arg)) {
-      template_def << get_type_string(std::get<Dtype>(arg));
-    }
-    i++;
-  }
-  template_def << ">";
-  return template_def.str();
-}
-
-MetalKernelFunction metal_kernel(
-    const std::string& name,
-    const std::vector<std::string>& input_names,
-    const std::vector<std::string>& output_names,
-    const std::string& source,
-    const std::string& header /* = "" */,
-    bool ensure_row_contiguous /* = true */,
-    bool atomic_outputs /* = false */) {
-  if (output_names.empty()) {
-    throw std::invalid_argument(
-        "[metal_kernel] Must specify at least one output.");
-  }
-
-  std::vector<CustomKernelShapeInfo> shape_infos;
-  for (auto& n : input_names) {
-    CustomKernelShapeInfo shape_info;
-    shape_info.shape = source.find(n + "_shape") != std::string::npos;
-    shape_info.strides = source.find(n + "_strides") != std::string::npos;
-    shape_info.ndim = source.find(n + "_ndim") != std::string::npos;
-    shape_infos.push_back(shape_info);
-  }
-
-  const std::vector<std::pair<std::string, std::string>> metal_attributes = {
-      {"dispatch_quadgroups_per_threadgroup", "uint"},
-      {"dispatch_simdgroups_per_threadgroup", "uint"},
-      {"dispatch_threads_per_threadgroup", "uint3"},
-      {"grid_origin", "uint3"},
-      {"grid_size", "uint3"},
-      {"quadgroup_index_in_threadgroup", "uint"},
-      {"quadgroups_per_threadgroup", "uint"},
-      {"simdgroup_index_in_threadgroup", "uint"},
-      {"simdgroups_per_threadgroup", "uint"},
-      {"thread_execution_width", "uint"},
-      {"thread_index_in_quadgroup", "uint"},
-      {"thread_index_in_simdgroup", "uint"},
-      {"thread_index_in_threadgroup", "uint"},
-      {"thread_position_in_grid", "uint3"},
-      {"thread_position_in_threadgroup", "uint3"},
-      {"threadgroup_position_in_grid", "uint3"},
-      {"threadgroups_per_grid", "uint3"},
-      {"threads_per_grid", "uint3"},
-      {"threads_per_simdgroup", "uint"},
-      {"threads_per_threadgroup", "uint3"},
-  };
-
-  std::vector<std::string> attributes;
-  for (const auto& [attr, dtype] : metal_attributes) {
-    if (source.find(attr) != std::string::npos) {
-      attributes.push_back("  " + dtype + " " + attr + " [[" + attr + "]]");
-    }
-  }
-  auto now = std::chrono::system_clock::now();
-  int64_t timestamp = std::chrono::duration_cast<std::chrono::milliseconds>(
-                          now.time_since_epoch())
-                          .count();
-
-  return [=,
-          shape_infos = std::move(shape_infos),
-          attributes = std::move(attributes)](
-             const std::vector<array>& inputs,
-             const std::vector<Shape>& output_shapes,
-             const std::vector<Dtype>& output_dtypes,
-             std::tuple<int, int, int> grid,
-             std::tuple<int, int, int> threadgroup,
-             const std::vector<std::pair<std::string, TemplateArg>>&
-                 template_args = {},
-             std::optional<float> init_value = std::nullopt,
-             bool verbose = false,
-             StreamOrDevice s_ = {}) {
-    if (inputs.size() != input_names.size()) {
-      std::ostringstream msg;
-      msg << "[metal_kernel] Expected `inputs` to have size "
-          << input_names.size() << " but got size " << inputs.size() << "."
-          << std::endl;
-      throw std::invalid_argument(msg.str());
-    }
-    if (output_shapes.size() != output_names.size()) {
-      std::ostringstream msg;
-      msg << "[metal_kernel] Expected `output_shapes` to have size "
-          << output_names.size() << " but got size " << output_shapes.size()
-          << "." << std::endl;
-      throw std::invalid_argument(msg.str());
-    }
-    if (output_dtypes.size() != output_names.size()) {
-      std::ostringstream msg;
-      msg << "[metal_kernel] Expected `output_dtypes` to have size "
-          << output_names.size() << " but got size " << output_dtypes.size()
-          << "." << std::endl;
-      throw std::invalid_argument(msg.str());
-    }
-
-    auto s = to_stream(s_);
-    if (s.device != Device::gpu) {
-      throw std::invalid_argument("[metal_kernel] Only supports the GPU.");
-    }
-
-    std::ostringstream func_name;
-    std::string template_def = "";
-    std::string template_hash = "";
-    if (!template_args.empty()) {
-      std::regex disallowed_chars("\\<|\\>|(, )");
-      template_def = write_template(template_args);
-      template_hash = std::regex_replace(template_def, disallowed_chars, "_");
-      template_hash.pop_back();
-    }
-    func_name << "custom_kernel_" << name << "_" << template_hash << "_"
-              << timestamp;
-    std::string kernel_name = func_name.str();
-
-    std::string kernel_source = write_signature(
-        kernel_name,
-        header,
-        source,
-        input_names,
-        inputs,
-        output_names,
-        output_dtypes,
-        template_args,
-        attributes,
-        shape_infos,
-        atomic_outputs);
-
-    if (!template_args.empty()) {
-      template_def = kernel_name + template_def;
-      kernel_source += "\ntemplate [[host_name(\"";
-      kernel_source += kernel_name;
-      kernel_source += "\")]] [[kernel]] decltype(";
-      kernel_source += template_def;
-      kernel_source += ") ";
-      kernel_source += template_def;
-      kernel_source += ";\n";
-    }
-
-    if (verbose) {
-      std::cout << "Generated source code for `" << name << "`:" << std::endl
-                << "```" << std::endl
-                << kernel_source << std::endl
-                << "```" << std::endl;
-    }
-
-    return array::make_arrays(
-        std::move(output_shapes),
-        std::move(output_dtypes),
-        std::make_shared<CustomKernel>(
-            s,
-            std::move(kernel_name),
-            std::move(kernel_source),
-            grid,
-            threadgroup,
-            shape_infos,
-            ensure_row_contiguous,
-            init_value),
-        std::move(inputs));
-  };
-}
-
 } // namespace mlx::core::fast
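
Usage note: a minimal sketch of the API this patch moves into the Metal
backend, assuming the `mlx::core::fast::metal_kernel` entry point shown above
(the kernel name "myexp" and its body are illustrative only). Because kernel
names no longer embed a timestamp, repeated calls with the same `source` reuse
the cached library; when the `source` registered under a name changes,
`CustomKernel::eval_gpu` calls `Device::clear_library` to evict the stale
library and its cached pipeline states before rebuilding.

    using namespace mlx::core;

    auto x = ones({1024}, float32);
    auto exp_fn = fast::metal_kernel(
        "myexp",  // kernel name
        {"inp"},  // input names
        {"out"},  // output names
        // Kernel body; referencing `thread_position_in_grid` pulls the
        // matching [[thread_position_in_grid]] attribute into the
        // generated signature.
        R"(
          uint elem = thread_position_in_grid.x;
          out[elem] = metal::exp(inp[elem]);
        )");
    // MetalKernelFunction is a std::function, so all arguments are passed
    // explicitly (defaults in the lambda do not survive type erasure).
    auto outs = exp_fn(
        {x},            // inputs
        {x.shape()},    // output shapes
        {x.dtype()},    // output dtypes
        {1024, 1, 1},   // grid
        {256, 1, 1},    // threadgroup
        {},             // template args
        std::nullopt,   // init value
        false,          // verbose
        {});            // stream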