Faster Metal compiled kernels + some fixes (#1486)

* bump mac tests to use py39

* work per thread for compiled kernels

* fix for large arrays

* fix
This commit is contained in:
Awni Hannun
2024-10-14 12:45:38 -07:00
committed by GitHub
parent 0eef4febfd
commit 881615b072
12 changed files with 157 additions and 108 deletions

View File

@@ -38,8 +38,7 @@ void ternary_op_gpu_inplace(
bool use_2d = out.data_size() > UINT_MAX;
auto ndim = shape.size();
int work_per_thread =
(topt == TernaryOpType::General && shape[ndim - 1] > 4) ? 4 : 1;
int work_per_thread = (topt == TernaryOpType::General) ? 4 : 1;
std::string kernel_name;
{
std::ostringstream kname;