mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 19:38:16 +08:00
Custom Metal Kernels from Python (#1325)
* start * simple kernels working * restructure * inverse example working * docs + fixes * missing file * fix imports * address comments * add docs + fix test * Review comments + refactor to a single function * update docs * remove hashing * fix contig bug in test * back to a class * trailing whitespace * fix tests * match c++ and python apis * add link + make args kw_only
This commit is contained in:
@@ -131,6 +131,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
|
84
mlx/backend/metal/custom_kernel.cpp
Normal file
84
mlx/backend/metal/custom_kernel.cpp
Normal file
@@ -0,0 +1,84 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/jit/includes.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
|
||||
namespace mlx::core::fast {
|
||||
|
||||
void CustomKernel::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
auto& s = stream();
|
||||
|
||||
for (auto& out : outputs) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
|
||||
std::vector<array> copies;
|
||||
|
||||
auto check_input = [&copies, &s, this](const array& x) -> const array {
|
||||
bool no_copy = x.flags().row_contiguous;
|
||||
if (!ensure_row_contiguous_ || no_copy) {
|
||||
return x;
|
||||
} else {
|
||||
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
|
||||
copy_gpu(x, copies.back(), CopyType::General, s);
|
||||
return copies.back();
|
||||
}
|
||||
};
|
||||
std::vector<const array> checked_inputs;
|
||||
for (const array& in : inputs) {
|
||||
checked_inputs.push_back(check_input(in));
|
||||
}
|
||||
|
||||
auto& d = metal::device(s.device);
|
||||
const auto& lib_name = name_;
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
lib = d.get_library(lib_name, metal::utils() + source_);
|
||||
}
|
||||
auto kernel = d.get_kernel(name_, lib);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
int index = 0;
|
||||
for (int i = 0; i < checked_inputs.size(); i++) {
|
||||
const array& in = checked_inputs[i];
|
||||
auto shape_info = shape_infos_[i];
|
||||
compute_encoder.set_input_array(in, index);
|
||||
index++;
|
||||
if (in.ndim() > 0) {
|
||||
int ndim = in.ndim();
|
||||
if (shape_info.shape) {
|
||||
set_vector_bytes(compute_encoder, in.shape(), ndim, index);
|
||||
index++;
|
||||
}
|
||||
if (shape_info.strides) {
|
||||
set_vector_bytes(compute_encoder, in.strides(), ndim, index);
|
||||
index++;
|
||||
}
|
||||
if (shape_info.ndim) {
|
||||
compute_encoder->setBytes(&ndim, sizeof(int), index);
|
||||
index++;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (array out : outputs) {
|
||||
compute_encoder.set_output_array(out, index);
|
||||
index++;
|
||||
}
|
||||
|
||||
const auto [tx, ty, tz] = threadgroup_;
|
||||
MTL::Size group_dims = MTL::Size(tx, ty, tz);
|
||||
const auto [gx, gy, gz] = grid_;
|
||||
MTL::Size grid_dims = MTL::Size(gx, gy, gz);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
|
||||
if (!copies.empty()) {
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core::fast
|
@@ -119,6 +119,7 @@ NO_GPU_MULTI(RMSNormVJP)
|
||||
NO_GPU_MULTI(RoPE)
|
||||
NO_GPU(ScaledDotProductAttention)
|
||||
NO_GPU_MULTI(AffineQuantize)
|
||||
NO_GPU_MULTI(CustomKernel)
|
||||
} // namespace fast
|
||||
|
||||
} // namespace mlx::core
|
||||
|
270
mlx/fast.cpp
270
mlx/fast.cpp
@@ -1,7 +1,10 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <regex>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/fast.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/ops.h"
|
||||
@@ -913,4 +916,271 @@ array affine_dequantize(
|
||||
return fallback({w, scales, biases})[0];
|
||||
}
|
||||
|
||||
void validate_output_shapes(
|
||||
std::map<std::string, std::vector<int>> output_shapes,
|
||||
std::map<std::string, Dtype> output_dtypes) {
|
||||
// Make sure output shapes and dtypes have the same keys
|
||||
bool validated = true;
|
||||
if (output_shapes.size() == 0) {
|
||||
throw std::invalid_argument(
|
||||
"[metal_kernel] Must specify at least one output.");
|
||||
}
|
||||
if (output_shapes.size() != output_dtypes.size()) {
|
||||
validated = false;
|
||||
} else {
|
||||
for (const auto& kv : output_shapes) {
|
||||
if (output_dtypes.find(kv.first) == output_dtypes.end()) {
|
||||
validated = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!validated) {
|
||||
throw std::invalid_argument(
|
||||
"[metal_kernel] `output_shapes` and `output_dtypes` must have the same keys.");
|
||||
}
|
||||
}
|
||||
|
||||
void write_signature(
|
||||
std::string func_name,
|
||||
std::string& source,
|
||||
std::map<std::string, array>& inputs,
|
||||
std::map<std::string, std::vector<int>>& output_shapes,
|
||||
std::map<std::string, Dtype>& output_dtypes,
|
||||
std::optional<std::map<std::string, TemplateArg>> template_args,
|
||||
std::vector<CustomKernelShapeInfo>& shape_infos,
|
||||
std::ostringstream& kernel_source) {
|
||||
// Auto-generate a function signature based on `template_args`
|
||||
// and the dtype/shape of the arrays passed as `inputs`.
|
||||
if (template_args && template_args.value().size() > 0) {
|
||||
kernel_source << "template <";
|
||||
int i = 0;
|
||||
for (const auto& [name, arg] : template_args.value()) {
|
||||
std::string param_type;
|
||||
if (std::holds_alternative<int>(arg)) {
|
||||
param_type = "int";
|
||||
} else if (std::holds_alternative<bool>(arg)) {
|
||||
param_type = "bool";
|
||||
} else if (std::holds_alternative<Dtype>(arg)) {
|
||||
param_type = "typename";
|
||||
}
|
||||
if (i > 0) {
|
||||
kernel_source << ", ";
|
||||
}
|
||||
kernel_source << param_type << " " << name;
|
||||
i++;
|
||||
}
|
||||
kernel_source << ">" << std::endl;
|
||||
}
|
||||
kernel_source << "[[kernel]] void " << func_name << "(" << std::endl;
|
||||
|
||||
// Metal attributes are automatically added to the arguments if present
|
||||
const std::vector<std::pair<std::string, std::string>> metal_attributes = {
|
||||
{"dispatch_quadgroups_per_threadgroup", "uint"},
|
||||
{"dispatch_simdgroups_per_threadgroup", "uint"},
|
||||
{"dispatch_threads_per_threadgroup", "uint3"},
|
||||
{"grid_origin", "uint3"},
|
||||
{"grid_size", "uint3"},
|
||||
{"quadgroup_index_in_threadgroup", "uint"},
|
||||
{"quadgroups_per_threadgroup", "uint"},
|
||||
{"simdgroup_index_in_threadgroup", "uint"},
|
||||
{"simdgroups_per_threadgroup", "uint"},
|
||||
{"thread_execution_width", "uint"},
|
||||
{"thread_index_in_quadgroup", "uint"},
|
||||
{"thread_index_in_simdgroup", "uint"},
|
||||
{"thread_index_in_threadgroup", "uint"},
|
||||
{"thread_position_in_grid", "uint3"},
|
||||
{"thread_position_in_threadgroup", "uint3"},
|
||||
{"threadgroup_position_in_grid", "uint3"},
|
||||
{"threadgroups_per_grid", "uint3"},
|
||||
{"threads_per_grid", "uint3"},
|
||||
{"threads_per_simdgroup", "uint"},
|
||||
{"thread_per_threadgroup", "uint3"},
|
||||
};
|
||||
std::vector<std::pair<std::string, std::string>> attrs;
|
||||
for (const auto& [attr, dtype] : metal_attributes) {
|
||||
if (source.find(attr) != std::string::npos) {
|
||||
attrs.push_back({attr, dtype});
|
||||
}
|
||||
}
|
||||
|
||||
int index = 0;
|
||||
constexpr int max_constant_array_size = 8;
|
||||
// Add inputs
|
||||
for (const auto& [name, arr] : inputs) {
|
||||
auto dtype = get_type_string(arr.dtype());
|
||||
bool is_constant =
|
||||
arr.is_available() && arr.size() < max_constant_array_size;
|
||||
std::string location = is_constant ? "constant" : "device";
|
||||
std::string ref = arr.ndim() == 0 ? "&" : "*";
|
||||
kernel_source << " const " << location << " " << dtype << ref << " "
|
||||
<< name << " [[buffer(" << index << ")]]," << std::endl;
|
||||
index++;
|
||||
// Add input shape, strides and ndim if present in the source
|
||||
CustomKernelShapeInfo shape_info;
|
||||
if (arr.ndim() > 0) {
|
||||
if (source.find(name + "_shape") != std::string::npos) {
|
||||
kernel_source << " const constant int* " << name << "_shape [[buffer("
|
||||
<< index << ")]]," << std::endl;
|
||||
shape_info.shape = true;
|
||||
index++;
|
||||
}
|
||||
if (source.find(name + "_strides") != std::string::npos) {
|
||||
kernel_source << " const constant size_t* " << name
|
||||
<< "_strides [[buffer(" << index << ")]]," << std::endl;
|
||||
shape_info.strides = true;
|
||||
index++;
|
||||
}
|
||||
if (source.find(name + "_ndim") != std::string::npos) {
|
||||
kernel_source << " const constant int& " << name << "_ndim [[buffer("
|
||||
<< index << ")]]," << std::endl;
|
||||
shape_info.ndim = true;
|
||||
index++;
|
||||
}
|
||||
}
|
||||
shape_infos.push_back(shape_info);
|
||||
}
|
||||
// Add outputs
|
||||
for (const auto& [name, dtype] : output_dtypes) {
|
||||
kernel_source << " device " << get_type_string(dtype) << "* " << name
|
||||
<< " [[buffer(" << index << ")]]";
|
||||
if (index < inputs.size() + output_shapes.size() - 1 || attrs.size() > 0) {
|
||||
kernel_source << "," << std::endl;
|
||||
} else {
|
||||
kernel_source << ") {" << std::endl;
|
||||
}
|
||||
index++;
|
||||
}
|
||||
// Add metal attributes e.g. `threadgroup_index_in_grid`
|
||||
for (const auto& [attr, dtype] : attrs) {
|
||||
kernel_source << " " << dtype << " " << attr << " [[" << attr << "]]";
|
||||
if (index < attrs.size() - 1) {
|
||||
kernel_source << "," << std::endl;
|
||||
} else {
|
||||
kernel_source << ") {" << std::endl;
|
||||
}
|
||||
}
|
||||
kernel_source << source << std::endl;
|
||||
kernel_source << "}" << std::endl;
|
||||
}
|
||||
|
||||
std::string write_template(std::map<std::string, TemplateArg>& template_args) {
|
||||
std::ostringstream template_def;
|
||||
template_def << "<";
|
||||
int i = 0;
|
||||
for (const auto& [name, arg] : template_args) {
|
||||
if (i > 0) {
|
||||
template_def << ", ";
|
||||
}
|
||||
if (std::holds_alternative<int>(arg)) {
|
||||
template_def << std::get<int>(arg);
|
||||
} else if (std::holds_alternative<bool>(arg)) {
|
||||
template_def << std::get<bool>(arg);
|
||||
} else if (std::holds_alternative<Dtype>(arg)) {
|
||||
template_def << get_type_string(std::get<Dtype>(arg));
|
||||
}
|
||||
i++;
|
||||
}
|
||||
template_def << ">";
|
||||
return template_def.str();
|
||||
}
|
||||
|
||||
std::map<std::string, array> MetalKernel::operator()(
|
||||
std::map<std::string, array>& inputs,
|
||||
std::map<std::string, std::vector<int>> output_shapes,
|
||||
std::map<std::string, Dtype> output_dtypes,
|
||||
std::tuple<int, int, int> grid,
|
||||
std::tuple<int, int, int> threadgroup,
|
||||
std::optional<std::map<std::string, TemplateArg>> template_args,
|
||||
bool verbose,
|
||||
StreamOrDevice s_) {
|
||||
validate_output_shapes(output_shapes, output_dtypes);
|
||||
|
||||
auto s = to_stream(s_);
|
||||
if (s.device != Device::gpu) {
|
||||
throw std::invalid_argument(
|
||||
"[metal_kernel] MetalKernel only works on GPU.");
|
||||
}
|
||||
|
||||
std::ostringstream kernel_source;
|
||||
std::ostringstream func_name;
|
||||
|
||||
std::string template_def = "";
|
||||
bool needs_template = template_args && template_args.value().size() > 0;
|
||||
std::string hash_key = "";
|
||||
if (needs_template) {
|
||||
std::regex disallowed_chars("\\<|\\>|(, )");
|
||||
template_def = write_template(template_args.value());
|
||||
hash_key = std::regex_replace(template_def, disallowed_chars, "_");
|
||||
hash_key.pop_back();
|
||||
}
|
||||
|
||||
func_name << "custom_kernel_" << name_ << hash_key;
|
||||
std::string kernel_name = func_name.str();
|
||||
|
||||
std::vector<CustomKernelShapeInfo> shape_infos;
|
||||
write_signature(
|
||||
func_name.str(),
|
||||
source_,
|
||||
inputs,
|
||||
output_shapes,
|
||||
output_dtypes,
|
||||
template_args,
|
||||
shape_infos,
|
||||
kernel_source);
|
||||
|
||||
if (needs_template) {
|
||||
template_def = func_name.str() + template_def;
|
||||
kernel_source << std::endl
|
||||
<< "template [[host_name(\"" << kernel_name
|
||||
<< "\")]] [[kernel]] decltype(" << template_def << ") "
|
||||
<< template_def << ";" << std::endl;
|
||||
}
|
||||
|
||||
if (verbose) {
|
||||
std::cout << "Generated source code for `" << name_ << "`:" << std::endl
|
||||
<< "```" << std::endl
|
||||
<< kernel_source.str() << std::endl
|
||||
<< "```" << std::endl;
|
||||
}
|
||||
|
||||
std::vector<array> in_arrs;
|
||||
for (const auto& kv : inputs) {
|
||||
in_arrs.push_back(kv.second);
|
||||
}
|
||||
|
||||
std::vector<std::string> out_keys;
|
||||
std::vector<std::vector<int>> out_shapes;
|
||||
for (const auto& [name, shape] : output_shapes) {
|
||||
out_keys.push_back(name);
|
||||
out_shapes.push_back(shape);
|
||||
}
|
||||
|
||||
std::vector<Dtype> out_dtypes;
|
||||
for (const auto& kv : output_dtypes) {
|
||||
out_dtypes.push_back(kv.second);
|
||||
}
|
||||
|
||||
std::map<std::string, array> outputs;
|
||||
auto outputs_vec = array::make_arrays(
|
||||
out_shapes,
|
||||
out_dtypes,
|
||||
std::make_shared<CustomKernel>(
|
||||
s,
|
||||
kernel_name,
|
||||
kernel_source.str(),
|
||||
grid,
|
||||
threadgroup,
|
||||
shape_infos,
|
||||
ensure_row_contiguous_),
|
||||
in_arrs);
|
||||
|
||||
int i = 0;
|
||||
for (const auto& key : out_keys) {
|
||||
outputs.insert({key, outputs_vec[i]});
|
||||
i++;
|
||||
}
|
||||
return outputs;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::fast
|
||||
|
29
mlx/fast.h
29
mlx/fast.h
@@ -2,6 +2,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
#include <optional>
|
||||
|
||||
#include "mlx/utils.h"
|
||||
@@ -63,4 +64,32 @@ array affine_dequantize(
|
||||
int bits = 4,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
typedef std::variant<int, bool, Dtype> TemplateArg;
|
||||
|
||||
class MetalKernel {
|
||||
public:
|
||||
MetalKernel(
|
||||
const std::string& name,
|
||||
const std::string& source,
|
||||
bool ensure_row_contiguous)
|
||||
: name_(name),
|
||||
source_(source),
|
||||
ensure_row_contiguous_(ensure_row_contiguous) {}
|
||||
|
||||
std::map<std::string, array> operator()(
|
||||
std::map<std::string, array>& inputs,
|
||||
std::map<std::string, std::vector<int>> output_shapes,
|
||||
std::map<std::string, Dtype> output_dtypes,
|
||||
std::tuple<int, int, int> grid,
|
||||
std::tuple<int, int, int> threadgroup,
|
||||
std::optional<std::map<std::string, TemplateArg>> template_args =
|
||||
std::nullopt,
|
||||
bool verbose = false,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
private:
|
||||
std::string name_;
|
||||
std::string source_;
|
||||
bool ensure_row_contiguous_ = true;
|
||||
};
|
||||
} // namespace mlx::core::fast
|
||||
|
@@ -242,4 +242,47 @@ class AffineQuantize : public Custom {
|
||||
bool dequantize_;
|
||||
};
|
||||
|
||||
struct CustomKernelShapeInfo {
|
||||
bool shape = false;
|
||||
bool strides = false;
|
||||
bool ndim = false;
|
||||
};
|
||||
|
||||
class CustomKernel : public Primitive {
|
||||
public:
|
||||
CustomKernel(
|
||||
Stream stream,
|
||||
std::string name,
|
||||
std::string source,
|
||||
std::tuple<int, int, int> grid,
|
||||
std::tuple<int, int, int> threadgroup,
|
||||
std::vector<CustomKernelShapeInfo> shape_infos,
|
||||
bool ensure_row_contiguous)
|
||||
: Primitive(stream),
|
||||
source_(source),
|
||||
name_(name),
|
||||
grid_(grid),
|
||||
threadgroup_(threadgroup),
|
||||
shape_infos_(shape_infos),
|
||||
ensure_row_contiguous_(ensure_row_contiguous) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override {
|
||||
throw std::runtime_error("Custom Metal kernels only run on GPU.");
|
||||
}
|
||||
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
DEFINE_PRINT(CustomKernel);
|
||||
|
||||
private:
|
||||
std::string source_;
|
||||
std::string name_;
|
||||
std::tuple<int, int, int> grid_;
|
||||
std::tuple<int, int, int> threadgroup_;
|
||||
std::vector<CustomKernelShapeInfo> shape_infos_;
|
||||
bool ensure_row_contiguous_;
|
||||
};
|
||||
|
||||
} // namespace mlx::core::fast
|
||||
|
Reference in New Issue
Block a user