Fix ternary for large arrays (#1359)

* fix ternary for large arrays

* fix
This commit is contained in:
Awni Hannun
2024-08-26 11:22:27 -07:00
committed by GitHub
parent 860d3a50d7
commit 2fdf9eb535
2 changed files with 37 additions and 5 deletions

View File

@@ -12,6 +12,7 @@ namespace {
// TODO: Add support for more combinations of input types.
enum class TernaryOpType {
ScalarScalarScalar,
VectorVectorVector,
General,
};
@@ -20,6 +21,12 @@ get_ternary_op_type(const array& a, const array& b, const array& c) {
TernaryOpType topt;
if (a.data_size() == 1 && b.data_size() == 1 && c.data_size() == 1) {
topt = TernaryOpType::ScalarScalarScalar;
} else if (
(a.flags().row_contiguous && b.flags().row_contiguous &&
c.flags().row_contiguous) ||
(a.flags().col_contiguous && b.flags().col_contiguous &&
c.flags().col_contiguous)) {
topt = TernaryOpType::VectorVectorVector;
} else {
topt = TernaryOpType::General;
}
@@ -33,11 +40,32 @@ void set_ternary_op_output_data(
array& out,
TernaryOpType topt,
bool donate_with_move = false) {
auto maybe_donate = [&out, donate_with_move](const array& x) {
if (x.is_donatable() && x.itemsize() == out.itemsize()) {
if (donate_with_move) {
out.move_shared_buffer(x);
} else {
out.copy_shared_buffer(x);
}
return true;
}
return false;
};
switch (topt) {
case TernaryOpType::ScalarScalarScalar:
out.set_data(
allocator::malloc_or_wait(out.itemsize()), 1, b.strides(), b.flags());
break;
case TernaryOpType::VectorVectorVector:
if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
out.set_data(
allocator::malloc_or_wait(out.itemsize() * b.data_size()),
b.data_size(),
b.strides(),
b.flags());
}
break;
case TernaryOpType::General:
out.set_data(allocator::malloc_or_wait(out.nbytes()));
break;