Add optional headers to `mx.fast.metal_kernel` (#1358)

This commit is contained in:
Alex Barron
2024-08-26 21:45:45 -07:00
committed by GitHub
parent 5f7d19d1f5
commit 1d94ac3f90
4 changed files with 50 additions and 10 deletions

View File

@@ -1112,7 +1112,6 @@ std::map<std::string, array> MetalKernel::operator()(
"[metal_kernel] MetalKernel only works on GPU.");
}
std::ostringstream kernel_source;
std::ostringstream func_name;
std::string template_def = "";
@@ -1128,6 +1127,9 @@ std::map<std::string, array> MetalKernel::operator()(
func_name << "custom_kernel_" << name_ << hash_key;
std::string kernel_name = func_name.str();
std::ostringstream kernel_source;
kernel_source << header_ << std::endl;
std::vector<CustomKernelShapeInfo> shape_infos;
write_signature(
func_name.str(),