Custom cuda kernel (#2517)

This commit is contained in:
Angelos Katharopoulos
2025-08-20 17:20:22 -07:00
committed by GitHub
parent f4c8888cbe
commit e397177f6e
19 changed files with 1042 additions and 211 deletions

View File

@@ -172,7 +172,7 @@ std::string write_template(
return template_def.str();
}
MetalKernelFunction metal_kernel(
CustomKernelFunction metal_kernel(
const std::string& name,
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
@@ -316,7 +316,10 @@ MetalKernelFunction metal_kernel(
threadgroup,
shape_infos,
ensure_row_contiguous,
init_value),
init_value,
std::vector<ScalarArg>{},
false,
0),
std::move(inputs));
};
}

View File

@@ -26,15 +26,15 @@ device_info() {
namespace fast {
MetalKernelFunction metal_kernel(
CustomKernelFunction metal_kernel(
const std::string&,
const std::vector<std::string>&,
const std::vector<std::string>&,
const std::string&,
const std::string&,
bool ensure_row_contiguous,
bool atomic_outputs) {
throw std::runtime_error("[metal_kernel] No GPU back-end.");
bool,
bool) {
throw std::runtime_error("[metal_kernel] No Metal back-end.");
}
} // namespace fast