Improve metal elementwise kernels (#2247)

* improve metal elementwise kernels

* compile and copy

* fix jit
This commit is contained in:
Awni Hannun
2025-06-06 11:37:40 -07:00
committed by GitHub
parent a5ac9244c4
commit c6a20b427a
17 changed files with 412 additions and 174 deletions

View File

@@ -72,6 +72,10 @@ void concatenate(std::string& acc, T first, Args... args) {
// Number of elements a single thread should process for `dtype`.
// Computes 8 / dtype.size() with integer division and never returns less
// than 1, so a dtype whose size() exceeds 8 still gets one element.
// NOTE(review): presumably dtype.size() is the element width in bytes, which
// would make this a target of 8 bytes per thread; confirm against Dtype.
inline int get_work_per_thread(Dtype dtype) {
  const auto elems_per_thread = 8 / dtype.size();
  return std::max(1, elems_per_thread);
}
// Size-aware overload: number of elements a single thread should process.
// When `size` is below 1 << 16 (65536) each thread handles exactly one
// element; otherwise the result is max(1, 8 / dtype.size()).
// NOTE(review): `size` looks like the total element count of the array being
// processed; verify against callers.
inline int get_work_per_thread(Dtype dtype, size_t size) {
  constexpr size_t small_size_cutoff = 1 << 16;
  if (size < small_size_cutoff) {
    return 1;
  }
  return std::max(1, 8 / dtype.size());
}
inline size_t ceildiv(size_t n, size_t m) {
return (n + m - 1) / m;