mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 10:48:09 +08:00
Add optional headers to `mx.fast.metal_kernel` (#1358)
This commit is contained in:
@@ -1112,7 +1112,6 @@ std::map<std::string, array> MetalKernel::operator()(
|
||||
"[metal_kernel] MetalKernel only works on GPU.");
|
||||
}
|
||||
|
||||
std::ostringstream kernel_source;
|
||||
std::ostringstream func_name;
|
||||
|
||||
std::string template_def = "";
|
||||
@@ -1128,6 +1127,9 @@ std::map<std::string, array> MetalKernel::operator()(
|
||||
func_name << "custom_kernel_" << name_ << hash_key;
|
||||
std::string kernel_name = func_name.str();
|
||||
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << header_ << std::endl;
|
||||
|
||||
std::vector<CustomKernelShapeInfo> shape_infos;
|
||||
write_signature(
|
||||
func_name.str(),
|
||||
|
11
mlx/fast.h
11
mlx/fast.h
@@ -71,10 +71,12 @@ class MetalKernel {
|
||||
MetalKernel(
|
||||
const std::string& name,
|
||||
const std::string& source,
|
||||
bool ensure_row_contiguous,
|
||||
bool atomic_outputs)
|
||||
const std::string& header = "",
|
||||
bool ensure_row_contiguous = true,
|
||||
bool atomic_outputs = false)
|
||||
: name_(name),
|
||||
source_(source),
|
||||
header_(header),
|
||||
ensure_row_contiguous_(ensure_row_contiguous),
|
||||
atomic_outputs_(atomic_outputs) {}
|
||||
|
||||
@@ -93,7 +95,8 @@ class MetalKernel {
|
||||
private:
|
||||
std::string name_;
|
||||
std::string source_;
|
||||
bool ensure_row_contiguous_ = true;
|
||||
bool atomic_outputs_ = false;
|
||||
std::string header_;
|
||||
bool ensure_row_contiguous_;
|
||||
bool atomic_outputs_;
|
||||
};
|
||||
} // namespace mlx::core::fast
|
||||
|
Reference in New Issue
Block a user