Custom Metal Kernels from Python (#1325)

* start

* simple kernels working

* restructure

* inverse example working

* docs + fixes

* missing file

* fix imports

* address comments

* add docs + fix test

* Review comments + refactor to a single function

* update docs

* remove hashing

* fix contig bug in test

* back to a class

* trailing whitespace

* fix tests

* match c++ and python apis

* add link + make args kw_only
This commit is contained in:
Alex Barron
2024-08-22 13:46:29 -07:00
committed by GitHub
parent df3233454d
commit 0fd2a1f4b0
12 changed files with 793 additions and 4 deletions

View File

@@ -131,6 +131,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp

View File

@@ -0,0 +1,84 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"
namespace mlx::core::fast {
void CustomKernel::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& s = stream();
for (auto& out : outputs) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
std::vector<array> copies;
auto check_input = [&copies, &s, this](const array& x) -> const array {
bool no_copy = x.flags().row_contiguous;
if (!ensure_row_contiguous_ || no_copy) {
return x;
} else {
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
copy_gpu(x, copies.back(), CopyType::General, s);
return copies.back();
}
};
std::vector<const array> checked_inputs;
for (const array& in : inputs) {
checked_inputs.push_back(check_input(in));
}
auto& d = metal::device(s.device);
const auto& lib_name = name_;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
lib = d.get_library(lib_name, metal::utils() + source_);
}
auto kernel = d.get_kernel(name_, lib);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
int index = 0;
for (int i = 0; i < checked_inputs.size(); i++) {
const array& in = checked_inputs[i];
auto shape_info = shape_infos_[i];
compute_encoder.set_input_array(in, index);
index++;
if (in.ndim() > 0) {
int ndim = in.ndim();
if (shape_info.shape) {
set_vector_bytes(compute_encoder, in.shape(), ndim, index);
index++;
}
if (shape_info.strides) {
set_vector_bytes(compute_encoder, in.strides(), ndim, index);
index++;
}
if (shape_info.ndim) {
compute_encoder->setBytes(&ndim, sizeof(int), index);
index++;
}
}
}
for (array out : outputs) {
compute_encoder.set_output_array(out, index);
index++;
}
const auto [tx, ty, tz] = threadgroup_;
MTL::Size group_dims = MTL::Size(tx, ty, tz);
const auto [gx, gy, gz] = grid_;
MTL::Size grid_dims = MTL::Size(gx, gy, gz);
compute_encoder->dispatchThreads(grid_dims, group_dims);
if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
}
} // namespace mlx::core::fast

View File

@@ -119,6 +119,7 @@ NO_GPU_MULTI(RMSNormVJP)
NO_GPU_MULTI(RoPE)
NO_GPU(ScaledDotProductAttention)
NO_GPU_MULTI(AffineQuantize)
NO_GPU_MULTI(CustomKernel)
} // namespace fast
} // namespace mlx::core

View File

@@ -1,7 +1,10 @@
// Copyright © 2023-2024 Apple Inc.
#include <cassert>
#include <iostream>
#include <numeric>
#include <regex>
#include "mlx/backend/common/compiled.h"
#include "mlx/fast.h"
#include "mlx/fast_primitives.h"
#include "mlx/ops.h"
@@ -913,4 +916,271 @@ array affine_dequantize(
return fallback({w, scales, biases})[0];
}
void validate_output_shapes(
std::map<std::string, std::vector<int>> output_shapes,
std::map<std::string, Dtype> output_dtypes) {
// Make sure output shapes and dtypes have the same keys
bool validated = true;
if (output_shapes.size() == 0) {
throw std::invalid_argument(
"[metal_kernel] Must specify at least one output.");
}
if (output_shapes.size() != output_dtypes.size()) {
validated = false;
} else {
for (const auto& kv : output_shapes) {
if (output_dtypes.find(kv.first) == output_dtypes.end()) {
validated = false;
break;
}
}
}
if (!validated) {
throw std::invalid_argument(
"[metal_kernel] `output_shapes` and `output_dtypes` must have the same keys.");
}
}
void write_signature(
std::string func_name,
std::string& source,
std::map<std::string, array>& inputs,
std::map<std::string, std::vector<int>>& output_shapes,
std::map<std::string, Dtype>& output_dtypes,
std::optional<std::map<std::string, TemplateArg>> template_args,
std::vector<CustomKernelShapeInfo>& shape_infos,
std::ostringstream& kernel_source) {
// Auto-generate a function signature based on `template_args`
// and the dtype/shape of the arrays passed as `inputs`.
if (template_args && template_args.value().size() > 0) {
kernel_source << "template <";
int i = 0;
for (const auto& [name, arg] : template_args.value()) {
std::string param_type;
if (std::holds_alternative<int>(arg)) {
param_type = "int";
} else if (std::holds_alternative<bool>(arg)) {
param_type = "bool";
} else if (std::holds_alternative<Dtype>(arg)) {
param_type = "typename";
}
if (i > 0) {
kernel_source << ", ";
}
kernel_source << param_type << " " << name;
i++;
}
kernel_source << ">" << std::endl;
}
kernel_source << "[[kernel]] void " << func_name << "(" << std::endl;
// Metal attributes are automatically added to the arguments if present
const std::vector<std::pair<std::string, std::string>> metal_attributes = {
{"dispatch_quadgroups_per_threadgroup", "uint"},
{"dispatch_simdgroups_per_threadgroup", "uint"},
{"dispatch_threads_per_threadgroup", "uint3"},
{"grid_origin", "uint3"},
{"grid_size", "uint3"},
{"quadgroup_index_in_threadgroup", "uint"},
{"quadgroups_per_threadgroup", "uint"},
{"simdgroup_index_in_threadgroup", "uint"},
{"simdgroups_per_threadgroup", "uint"},
{"thread_execution_width", "uint"},
{"thread_index_in_quadgroup", "uint"},
{"thread_index_in_simdgroup", "uint"},
{"thread_index_in_threadgroup", "uint"},
{"thread_position_in_grid", "uint3"},
{"thread_position_in_threadgroup", "uint3"},
{"threadgroup_position_in_grid", "uint3"},
{"threadgroups_per_grid", "uint3"},
{"threads_per_grid", "uint3"},
{"threads_per_simdgroup", "uint"},
{"thread_per_threadgroup", "uint3"},
};
std::vector<std::pair<std::string, std::string>> attrs;
for (const auto& [attr, dtype] : metal_attributes) {
if (source.find(attr) != std::string::npos) {
attrs.push_back({attr, dtype});
}
}
int index = 0;
constexpr int max_constant_array_size = 8;
// Add inputs
for (const auto& [name, arr] : inputs) {
auto dtype = get_type_string(arr.dtype());
bool is_constant =
arr.is_available() && arr.size() < max_constant_array_size;
std::string location = is_constant ? "constant" : "device";
std::string ref = arr.ndim() == 0 ? "&" : "*";
kernel_source << " const " << location << " " << dtype << ref << " "
<< name << " [[buffer(" << index << ")]]," << std::endl;
index++;
// Add input shape, strides and ndim if present in the source
CustomKernelShapeInfo shape_info;
if (arr.ndim() > 0) {
if (source.find(name + "_shape") != std::string::npos) {
kernel_source << " const constant int* " << name << "_shape [[buffer("
<< index << ")]]," << std::endl;
shape_info.shape = true;
index++;
}
if (source.find(name + "_strides") != std::string::npos) {
kernel_source << " const constant size_t* " << name
<< "_strides [[buffer(" << index << ")]]," << std::endl;
shape_info.strides = true;
index++;
}
if (source.find(name + "_ndim") != std::string::npos) {
kernel_source << " const constant int& " << name << "_ndim [[buffer("
<< index << ")]]," << std::endl;
shape_info.ndim = true;
index++;
}
}
shape_infos.push_back(shape_info);
}
// Add outputs
for (const auto& [name, dtype] : output_dtypes) {
kernel_source << " device " << get_type_string(dtype) << "* " << name
<< " [[buffer(" << index << ")]]";
if (index < inputs.size() + output_shapes.size() - 1 || attrs.size() > 0) {
kernel_source << "," << std::endl;
} else {
kernel_source << ") {" << std::endl;
}
index++;
}
// Add metal attributes e.g. `threadgroup_index_in_grid`
for (const auto& [attr, dtype] : attrs) {
kernel_source << " " << dtype << " " << attr << " [[" << attr << "]]";
if (index < attrs.size() - 1) {
kernel_source << "," << std::endl;
} else {
kernel_source << ") {" << std::endl;
}
}
kernel_source << source << std::endl;
kernel_source << "}" << std::endl;
}
std::string write_template(std::map<std::string, TemplateArg>& template_args) {
std::ostringstream template_def;
template_def << "<";
int i = 0;
for (const auto& [name, arg] : template_args) {
if (i > 0) {
template_def << ", ";
}
if (std::holds_alternative<int>(arg)) {
template_def << std::get<int>(arg);
} else if (std::holds_alternative<bool>(arg)) {
template_def << std::get<bool>(arg);
} else if (std::holds_alternative<Dtype>(arg)) {
template_def << get_type_string(std::get<Dtype>(arg));
}
i++;
}
template_def << ">";
return template_def.str();
}
std::map<std::string, array> MetalKernel::operator()(
std::map<std::string, array>& inputs,
std::map<std::string, std::vector<int>> output_shapes,
std::map<std::string, Dtype> output_dtypes,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
std::optional<std::map<std::string, TemplateArg>> template_args,
bool verbose,
StreamOrDevice s_) {
validate_output_shapes(output_shapes, output_dtypes);
auto s = to_stream(s_);
if (s.device != Device::gpu) {
throw std::invalid_argument(
"[metal_kernel] MetalKernel only works on GPU.");
}
std::ostringstream kernel_source;
std::ostringstream func_name;
std::string template_def = "";
bool needs_template = template_args && template_args.value().size() > 0;
std::string hash_key = "";
if (needs_template) {
std::regex disallowed_chars("\\<|\\>|(, )");
template_def = write_template(template_args.value());
hash_key = std::regex_replace(template_def, disallowed_chars, "_");
hash_key.pop_back();
}
func_name << "custom_kernel_" << name_ << hash_key;
std::string kernel_name = func_name.str();
std::vector<CustomKernelShapeInfo> shape_infos;
write_signature(
func_name.str(),
source_,
inputs,
output_shapes,
output_dtypes,
template_args,
shape_infos,
kernel_source);
if (needs_template) {
template_def = func_name.str() + template_def;
kernel_source << std::endl
<< "template [[host_name(\"" << kernel_name
<< "\")]] [[kernel]] decltype(" << template_def << ") "
<< template_def << ";" << std::endl;
}
if (verbose) {
std::cout << "Generated source code for `" << name_ << "`:" << std::endl
<< "```" << std::endl
<< kernel_source.str() << std::endl
<< "```" << std::endl;
}
std::vector<array> in_arrs;
for (const auto& kv : inputs) {
in_arrs.push_back(kv.second);
}
std::vector<std::string> out_keys;
std::vector<std::vector<int>> out_shapes;
for (const auto& [name, shape] : output_shapes) {
out_keys.push_back(name);
out_shapes.push_back(shape);
}
std::vector<Dtype> out_dtypes;
for (const auto& kv : output_dtypes) {
out_dtypes.push_back(kv.second);
}
std::map<std::string, array> outputs;
auto outputs_vec = array::make_arrays(
out_shapes,
out_dtypes,
std::make_shared<CustomKernel>(
s,
kernel_name,
kernel_source.str(),
grid,
threadgroup,
shape_infos,
ensure_row_contiguous_),
in_arrs);
int i = 0;
for (const auto& key : out_keys) {
outputs.insert({key, outputs_vec[i]});
i++;
}
return outputs;
}
} // namespace mlx::core::fast

View File

@@ -2,6 +2,7 @@
#pragma once
#include <map>
#include <optional>
#include "mlx/utils.h"
@@ -63,4 +64,32 @@ array affine_dequantize(
int bits = 4,
StreamOrDevice s = {});
typedef std::variant<int, bool, Dtype> TemplateArg;
class MetalKernel {
public:
MetalKernel(
const std::string& name,
const std::string& source,
bool ensure_row_contiguous)
: name_(name),
source_(source),
ensure_row_contiguous_(ensure_row_contiguous) {}
std::map<std::string, array> operator()(
std::map<std::string, array>& inputs,
std::map<std::string, std::vector<int>> output_shapes,
std::map<std::string, Dtype> output_dtypes,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
std::optional<std::map<std::string, TemplateArg>> template_args =
std::nullopt,
bool verbose = false,
StreamOrDevice s = {});
private:
std::string name_;
std::string source_;
bool ensure_row_contiguous_ = true;
};
} // namespace mlx::core::fast

View File

@@ -242,4 +242,47 @@ class AffineQuantize : public Custom {
bool dequantize_;
};
struct CustomKernelShapeInfo {
bool shape = false;
bool strides = false;
bool ndim = false;
};
class CustomKernel : public Primitive {
public:
CustomKernel(
Stream stream,
std::string name,
std::string source,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
std::vector<CustomKernelShapeInfo> shape_infos,
bool ensure_row_contiguous)
: Primitive(stream),
source_(source),
name_(name),
grid_(grid),
threadgroup_(threadgroup),
shape_infos_(shape_infos),
ensure_row_contiguous_(ensure_row_contiguous) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
throw std::runtime_error("Custom Metal kernels only run on GPU.");
}
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
DEFINE_PRINT(CustomKernel);
private:
std::string source_;
std::string name_;
std::tuple<int, int, int> grid_;
std::tuple<int, int, int> threadgroup_;
std::vector<CustomKernelShapeInfo> shape_infos_;
bool ensure_row_contiguous_;
};
} // namespace mlx::core::fast