Add optional headers to `mx.fast.metal_kernel` (#1358)

This commit is contained in:
Alex Barron
2024-08-26 21:45:45 -07:00
committed by GitHub
parent 5f7d19d1f5
commit 1d94ac3f90
4 changed files with 50 additions and 10 deletions

View File

@@ -1112,7 +1112,6 @@ std::map<std::string, array> MetalKernel::operator()(
"[metal_kernel] MetalKernel only works on GPU.");
}
std::ostringstream kernel_source;
std::ostringstream func_name;
std::string template_def = "";
@@ -1128,6 +1127,9 @@ std::map<std::string, array> MetalKernel::operator()(
func_name << "custom_kernel_" << name_ << hash_key;
std::string kernel_name = func_name.str();
std::ostringstream kernel_source;
kernel_source << header_ << std::endl;
std::vector<CustomKernelShapeInfo> shape_infos;
write_signature(
func_name.str(),

View File

@@ -71,10 +71,12 @@ class MetalKernel {
MetalKernel(
const std::string& name,
const std::string& source,
bool ensure_row_contiguous,
bool atomic_outputs)
const std::string& header = "",
bool ensure_row_contiguous = true,
bool atomic_outputs = false)
: name_(name),
source_(source),
header_(header),
ensure_row_contiguous_(ensure_row_contiguous),
atomic_outputs_(atomic_outputs) {}
@@ -93,7 +95,8 @@ class MetalKernel {
private:
std::string name_;
std::string source_;
bool ensure_row_contiguous_ = true;
bool atomic_outputs_ = false;
std::string header_;
bool ensure_row_contiguous_;
bool atomic_outputs_;
};
} // namespace mlx::core::fast