mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
fix large ops (#1620)
This commit is contained in:
@@ -40,6 +40,9 @@ void ternary_op_gpu_inplace(
|
||||
auto ndim = shape.size();
|
||||
int work_per_thread;
|
||||
if (topt == TernaryOpType::General) {
|
||||
large |=
|
||||
(a.data_size() > UINT32_MAX || b.data_size() > UINT32_MAX ||
|
||||
c.data_size() > UINT32_MAX);
|
||||
work_per_thread = large ? 4 : 2;
|
||||
} else {
|
||||
work_per_thread = 1;
|
||||
|
||||
Reference in New Issue
Block a user