diff --git a/mlx/backend/metal/custom_kernel.cpp b/mlx/backend/metal/custom_kernel.cpp index b63f35c1c..9fac2e89f 100644 --- a/mlx/backend/metal/custom_kernel.cpp +++ b/mlx/backend/metal/custom_kernel.cpp @@ -49,7 +49,7 @@ void CustomKernel::eval_gpu( int index = 0; for (int i = 0; i < checked_inputs.size(); i++) { const array& in = checked_inputs[i]; - auto shape_info = shape_infos_[i]; + auto& shape_info = shape_infos_[i]; compute_encoder.set_input_array(in, index); index++; if (in.ndim() > 0) { @@ -68,7 +68,7 @@ void CustomKernel::eval_gpu( } } } - for (array out : outputs) { + for (auto& out : outputs) { compute_encoder.set_output_array(out, index); index++; } diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 6a3b38218..a3fe1ea1e 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -916,21 +916,25 @@ array affine_dequantize( return fallback({w, scales, biases})[0]; } -void write_signature( +std::string write_signature( std::string func_name, + const std::string& header, const std::string& source, const std::vector& input_names, const std::vector& inputs, const std::vector& output_names, const std::vector& output_dtypes, const std::vector>& template_args, - std::vector& shape_infos, - bool atomic_outputs, - std::ostringstream& kernel_source) { + const std::vector& attributes, + const std::vector& shape_infos, + bool atomic_outputs) { + std::string kernel_source; + kernel_source.reserve(header.size() + source.size() + 16384); + kernel_source += header; // Auto-generate a function signature based on `template_args` // and the dtype/shape of the arrays passed as `inputs`. if (!template_args.empty()) { - kernel_source << "template <"; + kernel_source += "template <"; int i = 0; for (const auto& [name, arg] : template_args) { std::string param_type; @@ -942,44 +946,18 @@ void write_signature( param_type = "typename"; } if (i > 0) { - kernel_source << ", "; + kernel_source += ", "; } - kernel_source << param_type << " " << name; + kernel_source += param_type; + kernel_source += " "; + kernel_source += name; i++; } - kernel_source << ">" << std::endl; - } - kernel_source << "[[kernel]] void " << func_name << "(" << std::endl; - - // Metal attributes are automatically added to the arguments if present - const std::vector> metal_attributes = { - {"dispatch_quadgroups_per_threadgroup", "uint"}, - {"dispatch_simdgroups_per_threadgroup", "uint"}, - {"dispatch_threads_per_threadgroup", "uint3"}, - {"grid_origin", "uint3"}, - {"grid_size", "uint3"}, - {"quadgroup_index_in_threadgroup", "uint"}, - {"quadgroups_per_threadgroup", "uint"}, - {"simdgroup_index_in_threadgroup", "uint"}, - {"simdgroups_per_threadgroup", "uint"}, - {"thread_execution_width", "uint"}, - {"thread_index_in_quadgroup", "uint"}, - {"thread_index_in_simdgroup", "uint"}, - {"thread_index_in_threadgroup", "uint"}, - {"thread_position_in_grid", "uint3"}, - {"thread_position_in_threadgroup", "uint3"}, - {"threadgroup_position_in_grid", "uint3"}, - {"threadgroups_per_grid", "uint3"}, - {"threads_per_grid", "uint3"}, - {"threads_per_simdgroup", "uint"}, - {"threads_per_threadgroup", "uint3"}, - }; - std::vector> attrs; - for (const auto& [attr, dtype] : metal_attributes) { - if (source.find(attr) != std::string::npos) { - attrs.push_back({attr, dtype}); - } + kernel_source += ">\n"; } + kernel_source += "[[kernel]] void "; + kernel_source += func_name; + kernel_source += "(\n"; int index = 0; constexpr int max_constant_array_size = 8; @@ -988,69 +966,82 @@ void write_signature( const auto& name = input_names[i]; const auto& arr = inputs[i]; auto dtype = get_type_string(arr.dtype()); - bool is_constant = - arr.is_available() && arr.size() < max_constant_array_size; - std::string location = is_constant ? "constant" : "device"; + std::string location = + arr.size() < max_constant_array_size ? "constant" : "device"; std::string ref = arr.ndim() == 0 ? "&" : "*"; - kernel_source << " const " << location << " " << dtype << ref << " " - << name << " [[buffer(" << index << ")]]," << std::endl; + kernel_source += " const "; + kernel_source += location; + kernel_source += " "; + kernel_source += dtype; + kernel_source += ref; + kernel_source += " "; + kernel_source += name; + kernel_source += " [[buffer("; + kernel_source += std::to_string(index); + kernel_source += ")]],\n"; index++; // Add input shape, strides and ndim if present in the source - CustomKernelShapeInfo shape_info; if (arr.ndim() > 0) { - if (source.find(name + "_shape") != std::string::npos) { - kernel_source << " const constant int* " << name << "_shape [[buffer(" - << index << ")]]," << std::endl; - shape_info.shape = true; + if (shape_infos[i].shape) { + kernel_source += + (" const constant int* " + name + "_shape [[buffer(" + + std::to_string(index) + ")]],\n"); index++; } - if (source.find(name + "_strides") != std::string::npos) { - kernel_source << " const constant size_t* " << name - << "_strides [[buffer(" << index << ")]]," << std::endl; - shape_info.strides = true; + if (shape_infos[i].strides) { + kernel_source += + (" const constant size_t* " + name + "_strides [[buffer(" + + std::to_string(index) + ")]],\n"); index++; } - if (source.find(name + "_ndim") != std::string::npos) { - kernel_source << " const constant int& " << name << "_ndim [[buffer(" - << index << ")]]," << std::endl; - shape_info.ndim = true; + if (shape_infos[i].ndim) { + kernel_source += + (" const constant int& " + name + "_ndim [[buffer(" + + std::to_string(index) + ")]],\n"); index++; } } - shape_infos.push_back(shape_info); } // Add outputs for (int i = 0; i < output_names.size(); ++i) { const auto& name = output_names[i]; const auto& dtype = output_dtypes[i]; - kernel_source << " device "; + kernel_source += " device "; auto type_string = get_type_string(dtype); if (atomic_outputs) { - kernel_source << "atomic<" << type_string << ">"; - } else { - kernel_source << type_string; + kernel_source += "atomic<"; } - kernel_source << "* " << name << " [[buffer(" << index << ")]]"; - if (index < inputs.size() + output_names.size() - 1 || attrs.size() > 0) { - kernel_source << "," << std::endl; + kernel_source += type_string; + if (atomic_outputs) { + kernel_source += ">"; + } + kernel_source += "* "; + kernel_source += name; + kernel_source += " [[buffer("; + kernel_source += std::to_string(index); + kernel_source += ")]]"; + if (index < inputs.size() + output_names.size() - 1 || + attributes.size() > 0) { + kernel_source += ",\n"; } else { - kernel_source << ") {" << std::endl; + kernel_source += ") {\n"; } index++; } - // Add metal attributes e.g. `threadgroup_index_in_grid` + index = 0; - for (const auto& [attr, dtype] : attrs) { - kernel_source << " " << dtype << " " << attr << " [[" << attr << "]]"; - if (index < attrs.size() - 1) { - kernel_source << "," << std::endl; + for (const auto& attr : attributes) { + kernel_source += attr; + if (index < attributes.size() - 1) { + kernel_source += ",\n"; } else { - kernel_source << ") {" << std::endl; + kernel_source += ") {\n"; } index++; } - kernel_source << source << std::endl; - kernel_source << "}" << std::endl; + kernel_source += source; + kernel_source += "\n}\n"; + return kernel_source; } std::string write_template( @@ -1087,8 +1078,48 @@ MetalKernelFunction metal_kernel( throw std::invalid_argument( "[metal_kernel] Must specify at least one output."); } + std::vector shape_infos; + for (auto& n : input_names) { + CustomKernelShapeInfo shape_info; + shape_info.shape = source.find(n + "_shape") != std::string::npos; + shape_info.strides = source.find(n + "_strides") != std::string::npos; + shape_info.ndim = source.find(n + "_ndim") != std::string::npos; + shape_infos.push_back(shape_info); + } + const std::vector> metal_attributes = { + {"dispatch_quadgroups_per_threadgroup", "uint"}, + {"dispatch_simdgroups_per_threadgroup", "uint"}, + {"dispatch_threads_per_threadgroup", "uint3"}, + {"grid_origin", "uint3"}, + {"grid_size", "uint3"}, + {"quadgroup_index_in_threadgroup", "uint"}, + {"quadgroups_per_threadgroup", "uint"}, + {"simdgroup_index_in_threadgroup", "uint"}, + {"simdgroups_per_threadgroup", "uint"}, + {"thread_execution_width", "uint"}, + {"thread_index_in_quadgroup", "uint"}, + {"thread_index_in_simdgroup", "uint"}, + {"thread_index_in_threadgroup", "uint"}, + {"thread_position_in_grid", "uint3"}, + {"thread_position_in_threadgroup", "uint3"}, + {"threadgroup_position_in_grid", "uint3"}, + {"threadgroups_per_grid", "uint3"}, + {"threads_per_grid", "uint3"}, + {"threads_per_simdgroup", "uint"}, + {"threads_per_threadgroup", "uint3"}, + }; - return [=](const std::vector& inputs, + std::vector attributes; + for (const auto& [attr, dtype] : metal_attributes) { + if (source.find(attr) != std::string::npos) { + attributes.push_back(" " + dtype + " " + attr + " [[" + attr + "]]"); + } + } + + return [=, + shape_infos = std::move(shape_infos), + attributes = std::move(attributes)]( + const std::vector& inputs, const std::vector>& output_shapes, const std::vector& output_dtypes, std::tuple grid, @@ -1126,7 +1157,6 @@ MetalKernelFunction metal_kernel( } std::ostringstream func_name; - std::string template_def = ""; std::string hash_key = ""; if (!template_args.empty()) { @@ -1135,54 +1165,53 @@ MetalKernelFunction metal_kernel( hash_key = std::regex_replace(template_def, disallowed_chars, "_"); hash_key.pop_back(); } - func_name << "custom_kernel_" << name << hash_key; std::string kernel_name = func_name.str(); - std::ostringstream kernel_source; - kernel_source << header << std::endl; - - std::vector shape_infos; - write_signature( - func_name.str(), + std::string kernel_source = write_signature( + kernel_name, + header, source, input_names, inputs, output_names, output_dtypes, template_args, + attributes, shape_infos, - atomic_outputs, - kernel_source); + atomic_outputs); if (!template_args.empty()) { - template_def = func_name.str() + template_def; - kernel_source << std::endl - << "template [[host_name(\"" << kernel_name - << "\")]] [[kernel]] decltype(" << template_def << ") " - << template_def << ";" << std::endl; + template_def = kernel_name + template_def; + kernel_source += "\ntemplate [[host_name(\""; + kernel_source += kernel_name; + kernel_source += "\")]] [[kernel]] decltype("; + kernel_source += template_def; + kernel_source += ") "; + kernel_source += template_def; + kernel_source += ";\n"; } if (verbose) { std::cout << "Generated source code for `" << name << "`:" << std::endl << "```" << std::endl - << kernel_source.str() << std::endl + << kernel_source << std::endl << "```" << std::endl; } return array::make_arrays( - output_shapes, - output_dtypes, + std::move(output_shapes), + std::move(output_dtypes), std::make_shared( s, - kernel_name, - kernel_source.str(), + std::move(kernel_name), + std::move(kernel_source), grid, threadgroup, shape_infos, ensure_row_contiguous, init_value), - inputs); + std::move(inputs)); }; } diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h index 1d01610f3..9233a1628 100644 --- a/mlx/fast_primitives.h +++ b/mlx/fast_primitives.h @@ -262,11 +262,11 @@ class CustomKernel : public Primitive { bool ensure_row_contiguous, std::optional init_value) : Primitive(stream), - source_(source), - name_(name), + source_(std::move(source)), + name_(std::move(name)), grid_(grid), threadgroup_(threadgroup), - shape_infos_(shape_infos), + shape_infos_(std::move(shape_infos)), ensure_row_contiguous_(ensure_row_contiguous), init_value_(init_value) {}