Mirror of https://github.com/ml-explore/mlx.git, synced 2025-12-16 01:49:05 +08:00
export custom kernel (#2756)
@@ -57,7 +57,7 @@ std::string build_kernel(
     const std::vector<std::string>& output_names,
     const std::vector<Dtype>& output_dtypes,
     const std::vector<std::pair<std::string, TemplateArg>>& template_args,
-    const std::vector<CustomKernelShapeInfo>& shape_infos) {
+    const std::vector<std::tuple<bool, bool, bool>>& shape_infos) {
   std::string kernel_source;
   kernel_source.reserve(header.size() + source.size() + 8192);
   kernel_source += default_header;
@@ -81,17 +81,17 @@ std::string build_kernel(
     kernel_source += ",\n";
     // Add input shape, strides and ndim if present in the source
     if (arr.ndim() > 0) {
-      if (shape_infos[i].shape) {
+      if (std::get<0>(shape_infos[i])) {
        kernel_source += " const __grid_constant__ Shape ";
        kernel_source += name;
        kernel_source += "_shape,\n";
      }
-      if (shape_infos[i].strides) {
+      if (std::get<1>(shape_infos[i])) {
        kernel_source += " const __grid_constant__ Strides ";
        kernel_source += name;
        kernel_source += "_strides,\n";
      }
-      if (shape_infos[i].ndim) {
+      if (std::get<2>(shape_infos[i])) {
        kernel_source += " const __grid_constant__ int ";
        kernel_source += name;
        kernel_source += "_ndim,\n";
@@ -154,12 +154,12 @@ CustomKernelFunction cuda_kernel(
         "[custom_kernel] Must specify at least one output.");
   }
 
-  std::vector<CustomKernelShapeInfo> shape_infos;
+  std::vector<std::tuple<bool, bool, bool>> shape_infos;
   for (auto& n : input_names) {
-    CustomKernelShapeInfo shape_info;
-    shape_info.shape = source.find(n + "_shape") != std::string::npos;
-    shape_info.strides = source.find(n + "_strides") != std::string::npos;
-    shape_info.ndim = source.find(n + "_ndim") != std::string::npos;
+    std::tuple<bool, bool, bool> shape_info;
+    std::get<0>(shape_info) = source.find(n + "_shape") != std::string::npos;
+    std::get<1>(shape_info) = source.find(n + "_strides") != std::string::npos;
+    std::get<2>(shape_info) = source.find(n + "_ndim") != std::string::npos;
     shape_infos.push_back(shape_info);
   }
 
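The detection pass above is a plain substring scan: a kernel only pays for the shape metadata its source actually mentions. A freestanding sketch of the same idea (the function name detect_shape_infos is hypothetical, not part of MLX):

#include <iostream>
#include <string>
#include <tuple>
#include <vector>

// Scan the kernel body for references like "inp_shape" and record, per
// input, which of (shape, strides, ndim) must be bound at launch time.
std::vector<std::tuple<bool, bool, bool>> detect_shape_infos(
    const std::string& source,
    const std::vector<std::string>& input_names) {
  std::vector<std::tuple<bool, bool, bool>> infos;
  for (const auto& n : input_names) {
    infos.emplace_back(
        source.find(n + "_shape") != std::string::npos,
        source.find(n + "_strides") != std::string::npos,
        source.find(n + "_ndim") != std::string::npos);
  }
  return infos;
}

int main() {
  std::string src = "out[i] = inp[elem_to_loc(i, inp_shape, inp_strides, inp_ndim)];";
  auto [shape, strides, ndim] = detect_shape_infos(src, {"inp"})[0];
  std::cout << shape << strides << ndim << "\n"; // prints 111
}

Note the scan is purely textual, so mentioning inp_shape in a comment would also trigger the binding; that is harmless, since the extra buffer is simply passed and left unused.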
@@ -254,8 +254,8 @@ std::vector<array> precompiled_cuda_kernel(
     std::optional<float> init_value,
     bool ensure_row_contiguous,
     StreamOrDevice s) {
-  std::vector<CustomKernelShapeInfo> shape_infos(
-      inputs.size(), CustomKernelShapeInfo{false, false, false});
+  std::vector<std::tuple<bool, bool, bool>> shape_infos(
+      inputs.size(), {false, false, false});
   return array::make_arrays(
       output_shapes,
       output_dtypes,
@@ -327,13 +327,13 @@ void CustomKernel::eval_gpu(
     const array& in = checked_inputs[i];
     auto& shape_info = shape_infos_[i];
     args.append(in);
-    if (shape_info.shape) {
+    if (std::get<0>(shape_info)) {
       args.append_ndim(in.shape());
     }
-    if (shape_info.strides) {
+    if (std::get<1>(shape_info)) {
       args.append_ndim(in.strides());
     }
-    if (shape_info.ndim) {
+    if (std::get<2>(shape_info)) {
       args.append<int32_t>(in.ndim());
     }
   }
 
@@ -32,7 +32,7 @@ std::string write_signature(
     const std::vector<Dtype>& output_dtypes,
     const std::vector<std::pair<std::string, TemplateArg>>& template_args,
     const std::vector<std::string>& attributes,
-    const std::vector<CustomKernelShapeInfo>& shape_infos,
+    const std::vector<std::tuple<bool, bool, bool>>& shape_infos,
     bool atomic_outputs) {
   std::string kernel_source;
   kernel_source.reserve(header.size() + source.size() + 16384);
@@ -88,19 +88,19 @@ std::string write_signature(
     index++;
     // Add input shape, strides and ndim if present in the source
     if (arr.ndim() > 0) {
-      if (shape_infos[i].shape) {
+      if (std::get<0>(shape_infos[i])) {
        kernel_source +=
            (" const constant int* " + name + "_shape [[buffer(" +
             std::to_string(index) + ")]],\n");
        index++;
      }
-      if (shape_infos[i].strides) {
+      if (std::get<1>(shape_infos[i])) {
        kernel_source +=
            (" const constant int64_t* " + name + "_strides [[buffer(" +
             std::to_string(index) + ")]],\n");
        index++;
      }
-      if (shape_infos[i].ndim) {
+      if (std::get<2>(shape_infos[i])) {
        kernel_source +=
            (" const constant int& " + name + "_ndim [[buffer(" +
             std::to_string(index) + ")]],\n");
@@ -184,12 +184,12 @@ CustomKernelFunction metal_kernel(
     throw std::invalid_argument(
         "[metal_kernel] Must specify at least one output.");
   }
-  std::vector<CustomKernelShapeInfo> shape_infos;
+  std::vector<std::tuple<bool, bool, bool>> shape_infos;
   for (auto& n : input_names) {
-    CustomKernelShapeInfo shape_info;
-    shape_info.shape = source.find(n + "_shape") != std::string::npos;
-    shape_info.strides = source.find(n + "_strides") != std::string::npos;
-    shape_info.ndim = source.find(n + "_ndim") != std::string::npos;
+    std::tuple<bool, bool, bool> shape_info;
+    std::get<0>(shape_info) = source.find(n + "_shape") != std::string::npos;
+    std::get<1>(shape_info) = source.find(n + "_strides") != std::string::npos;
+    std::get<2>(shape_info) = source.find(n + "_ndim") != std::string::npos;
     shape_infos.push_back(shape_info);
   }
   const std::vector<std::pair<std::string, std::string>> metal_attributes = {
@@ -388,15 +388,15 @@ void CustomKernel::eval_gpu(
       index++;
       if (in.ndim() > 0) {
         int ndim = in.ndim();
-        if (shape_info.shape) {
+        if (std::get<0>(shape_info)) {
          compute_encoder.set_vector_bytes(in.shape(), ndim, index);
          index++;
        }
-        if (shape_info.strides) {
+        if (std::get<1>(shape_info)) {
          compute_encoder.set_vector_bytes(in.strides(), ndim, index);
          index++;
        }
-        if (shape_info.ndim) {
+        if (std::get<2>(shape_info)) {
          compute_encoder.set_bytes(ndim, index);
          index++;
        }
 
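On the Metal side, the launch path mirrors the signature path: write_signature emits an optional [[buffer(n)]] slot exactly when eval_gpu will later bind one, with a running index advancing past the same entries in both places. A toy model of that lockstep counting (the helper signature_slots is hypothetical, not MLX code):

#include <iostream>
#include <tuple>
#include <vector>

// Toy model: the same (shape, strides, ndim) tuple drives both signature
// generation and argument binding, so buffer indices stay in lockstep
// as the optional metadata arguments are skipped or included.
int signature_slots(const std::vector<std::tuple<bool, bool, bool>>& infos) {
  int index = 0;
  for (auto& [shape, strides, ndim] : infos) {
    index++;                         // the input buffer itself
    index += shape + strides + ndim; // optional metadata buffers
  }
  return index;
}

int main() {
  std::vector<std::tuple<bool, bool, bool>> infos = {
      {true, true, false}, {false, false, true}};
  std::cout << signature_slots(infos) << "\n"; // prints 5
}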
@@ -75,6 +75,14 @@ constexpr bool is_pair = is_specialization_of<std::pair, std::decay_t<T>>;
 template <typename T>
 constexpr bool is_tuple = is_specialization_of<std::tuple, std::decay_t<T>>;
 
+template <typename T>
+inline constexpr bool is_optional =
+    is_specialization_of<std::optional, std::decay_t<T>>;
+
+template <typename T>
+inline constexpr bool is_variant =
+    is_specialization_of<std::variant, std::decay_t<T>>;
+
 template <typename>
 constexpr bool dependent_false = false;
 
@@ -96,6 +104,12 @@ void reverse_bytes(T& data) {
   }
 }
 
+template <typename T>
+void serialize_variant(Writer& os, T v);
+
+template <typename T>
+T deserialize_variant(Reader& is);
+
 template <typename T>
 void serialize(Writer& os, T v) {
   if constexpr (std::is_arithmetic_v<T>) {
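The forward declarations exist because the functions are mutually recursive: serialize dispatches a variant to serialize_variant, whose visitor calls serialize back on the payload. The same shape in miniature (simplified stand-in types, not the MLX ones):

#include <iostream>
#include <type_traits>
#include <variant>

// The forward declaration breaks the cycle between the two templates.
template <typename T>
void write_variant(std::ostream& os, const T& v);

template <typename T>
void write(std::ostream& os, const T& v) {
  if constexpr (std::is_arithmetic_v<T>) {
    os << v;              // leaf case: a plain number
  } else {
    write_variant(os, v); // declared above, defined below
  }
}

template <typename T>
void write_variant(std::ostream& os, const T& v) {
  // The visitor re-enters write() on whichever alternative is held.
  std::visit([&](const auto& x) { write(os, x); }, v);
}

int main() {
  std::variant<int, float> v = 2.5f;
  write(std::cout, v); // prints 2.5
  std::cout << "\n";
}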
@@ -113,6 +127,13 @@ void serialize(Writer& os, T v) {
     }
   } else if constexpr (is_pair<T> || is_tuple<T>) {
     std::apply([&os](auto&... x) { (..., serialize(os, x)); }, v);
+  } else if constexpr (is_variant<T>) {
+    serialize_variant(os, v);
+  } else if constexpr (is_optional<T>) {
+    serialize(os, v.has_value());
+    if (v.has_value()) {
+      serialize(os, *v);
+    }
   } else {
     NotSerializable<T>();
   }
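The std::optional case added above frames the payload with a presence flag: a bool first, then the value only if one was set. A freestanding version of the same framing over a std::stringstream (a stand-in for MLX's Writer/Reader):

#include <iostream>
#include <optional>
#include <sstream>

// Encode an optional<float> as "has_value byte, then payload if set".
void put(std::ostream& os, const std::optional<float>& v) {
  char has = v.has_value() ? 1 : 0;
  os.write(&has, 1);
  if (v) {
    os.write(reinterpret_cast<const char*>(&*v), sizeof(float));
  }
}

std::optional<float> get(std::istream& is) {
  char has = 0;
  is.read(&has, 1);
  if (!has) {
    return std::nullopt;
  }
  float x;
  is.read(reinterpret_cast<char*>(&x), sizeof(float));
  return x;
}

int main() {
  std::stringstream ss;
  put(ss, 0.5f); // an init_value-style payload
  auto v = get(ss);
  std::cout << (v ? *v : -1.0f) << "\n"; // prints 0.5
}

This framing is what lets std::optional<float> appear in StateT further down: the serializer can now round-trip it like any other state field.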
@@ -145,11 +166,58 @@ T deserialize(Reader& is) {
   } else if constexpr (is_pair<T> || is_tuple<T>) {
     return deserialize_tuple<T>(
         is, std::make_index_sequence<std::tuple_size_v<std::decay_t<T>>>{});
+  } else if constexpr (is_optional<T>) {
+    auto has_value = deserialize<bool>(is);
+    if (has_value) {
+      return deserialize<typename T::value_type>(is);
+    } else {
+      return std::nullopt;
+    }
+  } else if constexpr (is_variant<T>) {
+    return deserialize_variant<T>(is);
   } else {
     NotDeserializable<T>();
   }
 }
 
+enum class VariantType { Int = 0, Float = 1, Bool = 2 };
+
+template <typename T>
+void serialize_variant(Writer& os, T v) {
+  std::visit(
+      [&](auto&& x) {
+        using ElemT = std::decay_t<decltype(x)>;
+        if constexpr (std::is_same_v<ElemT, int>) {
+          serialize(os, VariantType::Int);
+        } else if constexpr (std::is_same_v<ElemT, float>) {
+          serialize(os, VariantType::Float);
+        } else if constexpr (std::is_same_v<ElemT, bool>) {
+          serialize(os, VariantType::Bool);
+        } else {
+          static_assert(
+              std::is_same_v<ElemT, void>, "Can't serialize variant type.");
+        }
+        serialize(os, x);
+      },
+      v);
+}
+
+template <typename T>
+T deserialize_variant(Reader& is) {
+  auto vt = deserialize<VariantType>(is);
+  switch (vt) {
+    case VariantType::Int:
+      return deserialize<int>(is);
+    case VariantType::Float:
+      return deserialize<float>(is);
+    case VariantType::Bool:
+      return deserialize<bool>(is);
+    default:
+      throw std::runtime_error(
+          "[deserialize_variant] Unknown variant type tag.");
+  }
+}
+
 template <typename T, std::size_t... I>
 decltype(auto) deserialize_tuple(Reader& is, std::index_sequence<I...>) {
   return T{deserialize<std::tuple_element_t<I, T>>(is)...};
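Because a variant's active alternative isn't recoverable from the raw bytes alone, the encoding above writes an explicit type tag before the payload. A self-contained round-trip of the same scheme, with minimal stand-in Writer/Reader types (MLX's real ones differ):

#include <cstdint>
#include <cstring>
#include <iostream>
#include <string>
#include <variant>

// Minimal stand-ins for the Writer/Reader used by the real serializer.
struct Writer {
  std::string buf;
  void write(const char* p, size_t n) { buf.append(p, n); }
};
struct Reader {
  std::string buf;
  size_t pos = 0;
  void read(char* p, size_t n) { std::memcpy(p, buf.data() + pos, n); pos += n; }
};

enum class Tag : uint8_t { Int, Float, Bool };

// Write a one-byte tag identifying the alternative, then the raw payload.
void put(Writer& os, const std::variant<bool, int, float>& v) {
  std::visit(
      [&](auto x) {
        using E = decltype(x);
        Tag t = std::is_same_v<E, int>     ? Tag::Int
              : std::is_same_v<E, float>   ? Tag::Float
                                           : Tag::Bool;
        os.write(reinterpret_cast<const char*>(&t), sizeof(t));
        os.write(reinterpret_cast<const char*>(&x), sizeof(x));
      },
      v);
}

// Read the tag first, then decode the payload as that alternative.
std::variant<bool, int, float> get(Reader& is) {
  Tag t;
  is.read(reinterpret_cast<char*>(&t), sizeof(t));
  if (t == Tag::Int) {
    int x;
    is.read(reinterpret_cast<char*>(&x), sizeof(x));
    return x;
  } else if (t == Tag::Float) {
    float x;
    is.read(reinterpret_cast<char*>(&x), sizeof(x));
    return x;
  }
  bool x;
  is.read(reinterpret_cast<char*>(&x), sizeof(x));
  return x;
}

int main() {
  Writer w;
  put(w, 3.5f);
  Reader r{w.buf};
  std::cout << std::get<float>(get(r)) << "\n"; // prints 3.5
}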
@@ -374,7 +442,8 @@ struct PrimitiveFactory {
       SERIALIZE_PRIMITIVE(LayerNorm),
       SERIALIZE_PRIMITIVE(LayerNormVJP),
       SERIALIZE_PRIMITIVE(RoPE),
-      SERIALIZE_PRIMITIVE(ScaledDotProductAttention)};
+      SERIALIZE_PRIMITIVE(ScaledDotProductAttention),
+      SERIALIZE_PRIMITIVE(CustomKernel)};
   std::unordered_map<std::string, std::string> name_remap;
 
   PrimitiveFactory() {
 
@@ -2,6 +2,7 @@
 
 #pragma once
 
+#include <optional>
 #include <set>
 #include <unordered_map>
 #include <variant>
@@ -24,6 +25,9 @@ using StateT = std::variant<
     Strides,
     std::vector<int>,
     std::vector<size_t>,
+    std::vector<std::tuple<bool, bool, bool>>,
+    std::vector<std::variant<bool, int, float>>,
+    std::optional<float>,
     std::string>;
 
 using ExportCallbackInput = std::unordered_map<
 
@@ -315,12 +315,6 @@ class Quantize : public Custom {
   bool dequantize_;
 };
 
-struct CustomKernelShapeInfo {
-  bool shape = false;
-  bool strides = false;
-  bool ndim = false;
-};
-
 using ScalarArg = std::variant<bool, int, float>;
 
 class CustomKernel : public Primitive {
@@ -331,15 +325,15 @@ class CustomKernel : public Primitive {
       std::string source,
       std::tuple<int, int, int> grid,
       std::tuple<int, int, int> threadgroup,
-      std::vector<CustomKernelShapeInfo> shape_infos,
+      std::vector<std::tuple<bool, bool, bool>> shape_infos,
       bool ensure_row_contiguous,
       std::optional<float> init_value,
       std::vector<ScalarArg> scalar_arguments,
       bool is_precompiled,
       int shared_memory)
       : Primitive(stream),
-        source_(std::move(source)),
         name_(std::move(name)),
+        source_(std::move(source)),
         grid_(grid),
         threadgroup_(threadgroup),
         shape_infos_(std::move(shape_infos)),
@@ -358,13 +352,26 @@ class CustomKernel : public Primitive {
       override;
 
   DEFINE_NAME(CustomKernel);
+  auto state() const {
+    return std::make_tuple(
+        name_,
+        source_,
+        grid_,
+        threadgroup_,
+        shape_infos_,
+        ensure_row_contiguous_,
+        init_value_,
+        scalar_arguments_,
+        is_precompiled_,
+        shared_memory_);
+  }
 
  private:
-  std::string source_;
   std::string name_;
+  std::string source_;
   std::tuple<int, int, int> grid_;
   std::tuple<int, int, int> threadgroup_;
-  std::vector<CustomKernelShapeInfo> shape_infos_;
+  std::vector<std::tuple<bool, bool, bool>> shape_infos_;
   bool ensure_row_contiguous_;
   std::optional<float> init_value_;
   std::vector<ScalarArg> scalar_arguments_;
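The payoff of all the pieces is the new state() method: every field of CustomKernel is now one of the StateT alternatives, so the factory's generic tuple path can walk and serialize the primitive with no bespoke codec. This is also why CustomKernelShapeInfo became std::tuple<bool, bool, bool>: a tuple already satisfies the serializer's is_tuple path, whereas the struct would have needed its own. The pattern in miniature (toy types, not the MLX factory):

#include <iostream>
#include <string>
#include <tuple>
#include <type_traits>

// Toy primitive: state() packs the constructor arguments into a tuple.
struct FakeKernel {
  std::string name_ = "my_kernel";
  std::tuple<int, int, int> grid_{4, 4, 1};
  bool row_contiguous_ = true;
  auto state() const { return std::make_tuple(name_, grid_, row_contiguous_); }
};

// Generic walk: leaves are printed, tuples are expanded via std::apply,
// mirroring how the real serializer recurses through state tuples.
template <typename T>
void dump(const T& v) {
  if constexpr (std::is_arithmetic_v<T> || std::is_same_v<T, std::string>) {
    std::cout << v << " ";
  } else {
    std::apply([](const auto&... x) { (dump(x), ...); }, v);
  }
}

int main() {
  FakeKernel k;
  dump(k.state()); // prints: my_kernel 4 4 1 1
  std::cout << "\n";
}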