export custom kernel (#2756)

This commit is contained in:
Awni Hannun
2025-11-13 11:29:50 -08:00
committed by GitHub
parent 3f866be665
commit 8973550ff3
6 changed files with 161 additions and 37 deletions

View File

@@ -57,7 +57,7 @@ std::string build_kernel(
const std::vector<std::string>& output_names,
const std::vector<Dtype>& output_dtypes,
const std::vector<std::pair<std::string, TemplateArg>>& template_args,
const std::vector<CustomKernelShapeInfo>& shape_infos) {
const std::vector<std::tuple<bool, bool, bool>>& shape_infos) {
std::string kernel_source;
kernel_source.reserve(header.size() + source.size() + 8192);
kernel_source += default_header;
@@ -81,17 +81,17 @@ std::string build_kernel(
kernel_source += ",\n";
// Add input shape, strides and ndim if present in the source
if (arr.ndim() > 0) {
if (shape_infos[i].shape) {
if (std::get<0>(shape_infos[i])) {
kernel_source += " const __grid_constant__ Shape ";
kernel_source += name;
kernel_source += "_shape,\n";
}
if (shape_infos[i].strides) {
if (std::get<1>(shape_infos[i])) {
kernel_source += " const __grid_constant__ Strides ";
kernel_source += name;
kernel_source += "_strides,\n";
}
if (shape_infos[i].ndim) {
if (std::get<2>(shape_infos[i])) {
kernel_source += " const __grid_constant__ int ";
kernel_source += name;
kernel_source += "_ndim,\n";
@@ -154,12 +154,12 @@ CustomKernelFunction cuda_kernel(
"[custom_kernel] Must specify at least one output.");
}
std::vector<CustomKernelShapeInfo> shape_infos;
std::vector<std::tuple<bool, bool, bool>> shape_infos;
for (auto& n : input_names) {
CustomKernelShapeInfo shape_info;
shape_info.shape = source.find(n + "_shape") != std::string::npos;
shape_info.strides = source.find(n + "_strides") != std::string::npos;
shape_info.ndim = source.find(n + "_ndim") != std::string::npos;
std::tuple<bool, bool, bool> shape_info;
std::get<0>(shape_info) = source.find(n + "_shape") != std::string::npos;
std::get<1>(shape_info) = source.find(n + "_strides") != std::string::npos;
std::get<2>(shape_info) = source.find(n + "_ndim") != std::string::npos;
shape_infos.push_back(shape_info);
}
@@ -254,8 +254,8 @@ std::vector<array> precompiled_cuda_kernel(
std::optional<float> init_value,
bool ensure_row_contiguous,
StreamOrDevice s) {
std::vector<CustomKernelShapeInfo> shape_infos(
inputs.size(), CustomKernelShapeInfo{false, false, false});
std::vector<std::tuple<bool, bool, bool>> shape_infos(
inputs.size(), {false, false, false});
return array::make_arrays(
output_shapes,
output_dtypes,
@@ -327,13 +327,13 @@ void CustomKernel::eval_gpu(
const array& in = checked_inputs[i];
auto& shape_info = shape_infos_[i];
args.append(in);
if (shape_info.shape) {
if (std::get<0>(shape_info)) {
args.append_ndim(in.shape());
}
if (shape_info.strides) {
if (std::get<1>(shape_info)) {
args.append_ndim(in.strides());
}
if (shape_info.ndim) {
if (std::get<2>(shape_info)) {
args.append<int32_t>(in.ndim());
}
}