Improve metal elementwise kernels (#2247)

* improve metal elementwise kernels

* compile and copy

* fix jit
This commit is contained in:
Awni Hannun
2025-06-06 11:37:40 -07:00
committed by GitHub
parent a5ac9244c4
commit c6a20b427a
17 changed files with 412 additions and 174 deletions

View File

@@ -45,7 +45,7 @@ void ternary_op_gpu_inplace(
work_per_thread = large ? 4 : 2;
} else {
large = out.data_size() > INT32_MAX;
- work_per_thread = get_work_per_thread(b.dtype());
+ work_per_thread = get_work_per_thread(b.dtype(), out.data_size());
}
std::string kernel_name;
if (topt == TernaryOpType::General) {
@@ -60,6 +60,8 @@ void ternary_op_gpu_inplace(
}
} else if (large) {
kernel_name = "v2";
+ } else if (work_per_thread > 1) {
+ kernel_name = "vn";
} else {
kernel_name = "v";
}