mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add optional headers to `mx.fast.metal_kernel` (#1358)
This commit is contained in:
@@ -1112,7 +1112,6 @@ std::map<std::string, array> MetalKernel::operator()(
|
||||
"[metal_kernel] MetalKernel only works on GPU.");
|
||||
}
|
||||
|
||||
std::ostringstream kernel_source;
|
||||
std::ostringstream func_name;
|
||||
|
||||
std::string template_def = "";
|
||||
@@ -1128,6 +1127,9 @@ std::map<std::string, array> MetalKernel::operator()(
|
||||
func_name << "custom_kernel_" << name_ << hash_key;
|
||||
std::string kernel_name = func_name.str();
|
||||
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << header_ << std::endl;
|
||||
|
||||
std::vector<CustomKernelShapeInfo> shape_infos;
|
||||
write_signature(
|
||||
func_name.str(),
|
||||
|
||||
Reference in New Issue
Block a user