mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Improve metal elementwise kernels (#2247)
* improve metal elementwise kernels * compile and copy * fix jit
This commit is contained in:
@@ -45,7 +45,7 @@ void ternary_op_gpu_inplace(
|
||||
work_per_thread = large ? 4 : 2;
|
||||
} else {
|
||||
large = out.data_size() > INT32_MAX;
|
||||
work_per_thread = get_work_per_thread(b.dtype());
|
||||
work_per_thread = get_work_per_thread(b.dtype(), out.data_size());
|
||||
}
|
||||
std::string kernel_name;
|
||||
if (topt == TernaryOpType::General) {
|
||||
@@ -60,6 +60,8 @@ void ternary_op_gpu_inplace(
|
||||
}
|
||||
} else if (large) {
|
||||
kernel_name = "v2";
|
||||
} else if (work_per_thread > 1) {
|
||||
kernel_name = "vn";
|
||||
} else {
|
||||
kernel_name = "v";
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user