fix large ops (#1620)

This commit is contained in:
Awni Hannun
2024-11-24 09:17:10 -08:00
committed by GitHub
parent bb303c45a5
commit 211411faf2
12 changed files with 37 additions and 25 deletions

View File

@@ -40,6 +40,9 @@ void ternary_op_gpu_inplace(
auto ndim = shape.size();
int work_per_thread;
if (topt == TernaryOpType::General) {
large |=
(a.data_size() > UINT32_MAX || b.data_size() > UINT32_MAX ||
c.data_size() > UINT32_MAX);
work_per_thread = large ? 4 : 2;
} else {
work_per_thread = 1;