export custom kernel (#2756)

This commit is contained in:
Awni Hannun
2025-11-13 11:29:50 -08:00
committed by GitHub
parent 3f866be665
commit 8973550ff3
6 changed files with 161 additions and 37 deletions

View File

@@ -57,7 +57,7 @@ std::string build_kernel(
const std::vector<std::string>& output_names,
const std::vector<Dtype>& output_dtypes,
const std::vector<std::pair<std::string, TemplateArg>>& template_args,
const std::vector<CustomKernelShapeInfo>& shape_infos) {
const std::vector<std::tuple<bool, bool, bool>>& shape_infos) {
std::string kernel_source;
kernel_source.reserve(header.size() + source.size() + 8192);
kernel_source += default_header;
@@ -81,17 +81,17 @@ std::string build_kernel(
kernel_source += ",\n";
// Add input shape, strides and ndim if present in the source
if (arr.ndim() > 0) {
if (shape_infos[i].shape) {
if (std::get<0>(shape_infos[i])) {
kernel_source += " const __grid_constant__ Shape ";
kernel_source += name;
kernel_source += "_shape,\n";
}
if (shape_infos[i].strides) {
if (std::get<1>(shape_infos[i])) {
kernel_source += " const __grid_constant__ Strides ";
kernel_source += name;
kernel_source += "_strides,\n";
}
if (shape_infos[i].ndim) {
if (std::get<2>(shape_infos[i])) {
kernel_source += " const __grid_constant__ int ";
kernel_source += name;
kernel_source += "_ndim,\n";
@@ -154,12 +154,12 @@ CustomKernelFunction cuda_kernel(
"[custom_kernel] Must specify at least one output.");
}
std::vector<CustomKernelShapeInfo> shape_infos;
std::vector<std::tuple<bool, bool, bool>> shape_infos;
for (auto& n : input_names) {
CustomKernelShapeInfo shape_info;
shape_info.shape = source.find(n + "_shape") != std::string::npos;
shape_info.strides = source.find(n + "_strides") != std::string::npos;
shape_info.ndim = source.find(n + "_ndim") != std::string::npos;
std::tuple<bool, bool, bool> shape_info;
std::get<0>(shape_info) = source.find(n + "_shape") != std::string::npos;
std::get<1>(shape_info) = source.find(n + "_strides") != std::string::npos;
std::get<2>(shape_info) = source.find(n + "_ndim") != std::string::npos;
shape_infos.push_back(shape_info);
}
@@ -254,8 +254,8 @@ std::vector<array> precompiled_cuda_kernel(
std::optional<float> init_value,
bool ensure_row_contiguous,
StreamOrDevice s) {
std::vector<CustomKernelShapeInfo> shape_infos(
inputs.size(), CustomKernelShapeInfo{false, false, false});
std::vector<std::tuple<bool, bool, bool>> shape_infos(
inputs.size(), {false, false, false});
return array::make_arrays(
output_shapes,
output_dtypes,
@@ -327,13 +327,13 @@ void CustomKernel::eval_gpu(
const array& in = checked_inputs[i];
auto& shape_info = shape_infos_[i];
args.append(in);
if (shape_info.shape) {
if (std::get<0>(shape_info)) {
args.append_ndim(in.shape());
}
if (shape_info.strides) {
if (std::get<1>(shape_info)) {
args.append_ndim(in.strides());
}
if (shape_info.ndim) {
if (std::get<2>(shape_info)) {
args.append<int32_t>(in.ndim());
}
}

View File

@@ -32,7 +32,7 @@ std::string write_signature(
const std::vector<Dtype>& output_dtypes,
const std::vector<std::pair<std::string, TemplateArg>>& template_args,
const std::vector<std::string>& attributes,
const std::vector<CustomKernelShapeInfo>& shape_infos,
const std::vector<std::tuple<bool, bool, bool>>& shape_infos,
bool atomic_outputs) {
std::string kernel_source;
kernel_source.reserve(header.size() + source.size() + 16384);
@@ -88,19 +88,19 @@ std::string write_signature(
index++;
// Add input shape, strides and ndim if present in the source
if (arr.ndim() > 0) {
if (shape_infos[i].shape) {
if (std::get<0>(shape_infos[i])) {
kernel_source +=
(" const constant int* " + name + "_shape [[buffer(" +
std::to_string(index) + ")]],\n");
index++;
}
if (shape_infos[i].strides) {
if (std::get<1>(shape_infos[i])) {
kernel_source +=
(" const constant int64_t* " + name + "_strides [[buffer(" +
std::to_string(index) + ")]],\n");
index++;
}
if (shape_infos[i].ndim) {
if (std::get<2>(shape_infos[i])) {
kernel_source +=
(" const constant int& " + name + "_ndim [[buffer(" +
std::to_string(index) + ")]],\n");
@@ -184,12 +184,12 @@ CustomKernelFunction metal_kernel(
throw std::invalid_argument(
"[metal_kernel] Must specify at least one output.");
}
std::vector<CustomKernelShapeInfo> shape_infos;
std::vector<std::tuple<bool, bool, bool>> shape_infos;
for (auto& n : input_names) {
CustomKernelShapeInfo shape_info;
shape_info.shape = source.find(n + "_shape") != std::string::npos;
shape_info.strides = source.find(n + "_strides") != std::string::npos;
shape_info.ndim = source.find(n + "_ndim") != std::string::npos;
std::tuple<bool, bool, bool> shape_info;
std::get<0>(shape_info) = source.find(n + "_shape") != std::string::npos;
std::get<1>(shape_info) = source.find(n + "_strides") != std::string::npos;
std::get<2>(shape_info) = source.find(n + "_ndim") != std::string::npos;
shape_infos.push_back(shape_info);
}
const std::vector<std::pair<std::string, std::string>> metal_attributes = {
@@ -388,15 +388,15 @@ void CustomKernel::eval_gpu(
index++;
if (in.ndim() > 0) {
int ndim = in.ndim();
if (shape_info.shape) {
if (std::get<0>(shape_info)) {
compute_encoder.set_vector_bytes(in.shape(), ndim, index);
index++;
}
if (shape_info.strides) {
if (std::get<1>(shape_info)) {
compute_encoder.set_vector_bytes(in.strides(), ndim, index);
index++;
}
if (shape_info.ndim) {
if (std::get<2>(shape_info)) {
compute_encoder.set_bytes(ndim, index);
index++;
}

View File

@@ -75,6 +75,14 @@ constexpr bool is_pair = is_specialization_of<std::pair, std::decay_t<T>>;
template <typename T>
constexpr bool is_tuple = is_specialization_of<std::tuple, std::decay_t<T>>;
template <typename T>
inline constexpr bool is_optional =
is_specialization_of<std::optional, std::decay_t<T>>;
template <typename T>
inline constexpr bool is_variant =
is_specialization_of<std::variant, std::decay_t<T>>;
template <typename>
constexpr bool dependent_false = false;
@@ -96,6 +104,12 @@ void reverse_bytes(T& data) {
}
}
template <typename T>
void serialize_variant(Writer& os, T v);
template <typename T>
T deserialize_variant(Reader& is);
template <typename T>
void serialize(Writer& os, T v) {
if constexpr (std::is_arithmetic_v<T>) {
@@ -113,6 +127,13 @@ void serialize(Writer& os, T v) {
}
} else if constexpr (is_pair<T> || is_tuple<T>) {
std::apply([&os](auto&... x) { (..., serialize(os, x)); }, v);
} else if constexpr (is_variant<T>) {
serialize_variant(os, v);
} else if constexpr (is_optional<T>) {
serialize(os, v.has_value());
if (v.has_value()) {
serialize(os, *v);
}
} else {
NotSerializable<T>();
}
@@ -145,11 +166,58 @@ T deserialize(Reader& is) {
} else if constexpr (is_pair<T> || is_tuple<T>) {
return deserialize_tuple<T>(
is, std::make_index_sequence<std::tuple_size_v<std::decay_t<T>>>{});
} else if constexpr (is_optional<T>) {
auto has_value = deserialize<bool>(is);
if (has_value) {
return deserialize<T>(is);
} else {
return std::nullopt;
}
} else if constexpr (is_variant<T>) {
return deserialize_variant<T>(is);
} else {
NotDeserializable<T>();
}
}
// Tag written before a variant's payload so deserialize_variant knows which
// alternative to decode. The numeric values are part of the serialized
// format — do not reorder or renumber them.
enum class VariantType { Int = 0, Float = 1, Bool = 2 };
// Serialize a variant as a VariantType tag identifying the active
// alternative, followed by that alternative's value. Only int, float and
// bool alternatives are supported; any other alternative fails to compile
// via the static_assert below.
template <typename T>
void serialize_variant(Writer& os, T v) {
  auto write_alternative = [&os](auto&& alt) {
    using AltT = std::decay_t<decltype(alt)>;
    if constexpr (std::is_same_v<AltT, int>) {
      serialize(os, VariantType::Int);
    } else if constexpr (std::is_same_v<AltT, float>) {
      serialize(os, VariantType::Float);
    } else if constexpr (std::is_same_v<AltT, bool>) {
      serialize(os, VariantType::Bool);
    } else {
      // Trip only when instantiated with an unsupported alternative.
      static_assert(
          std::is_same_v<AltT, void>, "Can't serialize variant type.");
    }
    serialize(os, alt);
  };
  std::visit(write_alternative, v);
}
// Deserialize a variant previously written by serialize_variant: reads the
// VariantType tag, then the payload of the tagged alternative, and returns
// it converted to T (the variant type).
// Throws std::runtime_error if the tag does not name a known alternative
// (e.g. corrupt or truncated input).
template <typename T>
T deserialize_variant(Reader& is) {
  auto vt = deserialize<VariantType>(is);
  switch (vt) {
    case VariantType::Int:
      return deserialize<int>(is);
    case VariantType::Float:
      return deserialize<float>(is);
    case VariantType::Bool:
      return deserialize<bool>(is);
    default:
      // Fixed typo in the error message: "Unknonw" -> "Unknown".
      throw std::runtime_error(
          "[deserialize_variant] Unknown variant type tag.");
  }
}
template <typename T, std::size_t... I>
decltype(auto) deserialize_tuple(Reader& is, std::index_sequence<I...>) {
return T{deserialize<std::tuple_element_t<I, T>>(is)...};
@@ -374,7 +442,8 @@ struct PrimitiveFactory {
SERIALIZE_PRIMITIVE(LayerNorm),
SERIALIZE_PRIMITIVE(LayerNormVJP),
SERIALIZE_PRIMITIVE(RoPE),
SERIALIZE_PRIMITIVE(ScaledDotProductAttention)};
SERIALIZE_PRIMITIVE(ScaledDotProductAttention),
SERIALIZE_PRIMITIVE(CustomKernel)};
std::unordered_map<std::string, std::string> name_remap;
PrimitiveFactory() {

View File

@@ -2,6 +2,7 @@
#pragma once
#include <optional>
#include <set>
#include <unordered_map>
#include <variant>
@@ -24,6 +25,9 @@ using StateT = std::variant<
Strides,
std::vector<int>,
std::vector<size_t>,
std::vector<std::tuple<bool, bool, bool>>,
std::vector<std::variant<bool, int, float>>,
std::optional<float>,
std::string>;
using ExportCallbackInput = std::unordered_map<

View File

@@ -315,12 +315,6 @@ class Quantize : public Custom {
bool dequantize_;
};
struct CustomKernelShapeInfo {
bool shape = false;
bool strides = false;
bool ndim = false;
};
using ScalarArg = std::variant<bool, int, float>;
class CustomKernel : public Primitive {
@@ -331,15 +325,15 @@ class CustomKernel : public Primitive {
std::string source,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
std::vector<CustomKernelShapeInfo> shape_infos,
std::vector<std::tuple<bool, bool, bool>> shape_infos,
bool ensure_row_contiguous,
std::optional<float> init_value,
std::vector<ScalarArg> scalar_arguments,
bool is_precompiled,
int shared_memory)
: Primitive(stream),
source_(std::move(source)),
name_(std::move(name)),
source_(std::move(source)),
grid_(grid),
threadgroup_(threadgroup),
shape_infos_(std::move(shape_infos)),
@@ -358,13 +352,26 @@ class CustomKernel : public Primitive {
override;
DEFINE_NAME(CustomKernel);
auto state() const {
return std::make_tuple(
name_,
source_,
grid_,
threadgroup_,
shape_infos_,
ensure_row_contiguous_,
init_value_,
scalar_arguments_,
is_precompiled_,
shared_memory_);
}
private:
std::string source_;
std::string name_;
std::string source_;
std::tuple<int, int, int> grid_;
std::tuple<int, int, int> threadgroup_;
std::vector<CustomKernelShapeInfo> shape_infos_;
std::vector<std::tuple<bool, bool, bool>> shape_infos_;
bool ensure_row_contiguous_;
std::optional<float> init_value_;
std::vector<ScalarArg> scalar_arguments_;