Allow no-copy negative strides in as_strided and slice (#1688)

* allow no-copy negative strides in as_strided and slice

* fix jit

* fix jit
This commit is contained in:
Awni Hannun
2024-12-12 08:59:45 -08:00
committed by GitHub
parent 4d595a2a39
commit 6bd28d246e
15 changed files with 133 additions and 163 deletions

View File

@@ -36,15 +36,15 @@ void ternary_op_gpu_inplace(
};
auto [shape, strides_a, strides_b, strides_c, strides_out] = maybe_collapse();
bool large = out.data_size() > UINT_MAX;
bool large;
auto ndim = shape.size();
int work_per_thread;
if (topt == TernaryOpType::General) {
large |=
(a.data_size() > UINT32_MAX || b.data_size() > UINT32_MAX ||
c.data_size() > UINT32_MAX);
large = a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
c.data_size() > INT32_MAX || out.size() > INT32_MAX;
work_per_thread = large ? 4 : 2;
} else {
large = out.data_size() > INT32_MAX;
work_per_thread = 1;
}
std::string kernel_name;