mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Implement the 'where' primitive for conditional selection (#664)
This commit is contained in:
@@ -9,7 +9,7 @@ namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
enum BinaryOpType {
|
||||
enum class BinaryOpType {
|
||||
ScalarScalar,
|
||||
ScalarVector,
|
||||
VectorScalar,
|
||||
@@ -20,17 +20,17 @@ enum BinaryOpType {
|
||||
BinaryOpType get_binary_op_type(const array& a, const array& b) {
|
||||
BinaryOpType bopt;
|
||||
if (a.data_size() == 1 && b.data_size() == 1) {
|
||||
bopt = ScalarScalar;
|
||||
bopt = BinaryOpType::ScalarScalar;
|
||||
} else if (a.data_size() == 1 && b.flags().contiguous) {
|
||||
bopt = ScalarVector;
|
||||
bopt = BinaryOpType::ScalarVector;
|
||||
} else if (b.data_size() == 1 && a.flags().contiguous) {
|
||||
bopt = VectorScalar;
|
||||
bopt = BinaryOpType::VectorScalar;
|
||||
} else if (
|
||||
a.flags().row_contiguous && b.flags().row_contiguous ||
|
||||
a.flags().col_contiguous && b.flags().col_contiguous) {
|
||||
bopt = VectorVector;
|
||||
bopt = BinaryOpType::VectorVector;
|
||||
} else {
|
||||
bopt = General;
|
||||
bopt = BinaryOpType::General;
|
||||
}
|
||||
return bopt;
|
||||
}
|
||||
@@ -42,11 +42,11 @@ void set_binary_op_output_data(
|
||||
BinaryOpType bopt,
|
||||
bool donate_with_move = false) {
|
||||
switch (bopt) {
|
||||
case ScalarScalar:
|
||||
case BinaryOpType::ScalarScalar:
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
|
||||
break;
|
||||
case ScalarVector:
|
||||
case BinaryOpType::ScalarVector:
|
||||
if (b.is_donatable() && b.itemsize() == out.itemsize()) {
|
||||
if (donate_with_move) {
|
||||
out.move_shared_buffer(b);
|
||||
@@ -61,7 +61,7 @@ void set_binary_op_output_data(
|
||||
b.flags());
|
||||
}
|
||||
break;
|
||||
case VectorScalar:
|
||||
case BinaryOpType::VectorScalar:
|
||||
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
|
||||
if (donate_with_move) {
|
||||
out.move_shared_buffer(a);
|
||||
@@ -76,7 +76,7 @@ void set_binary_op_output_data(
|
||||
a.flags());
|
||||
}
|
||||
break;
|
||||
case VectorVector:
|
||||
case BinaryOpType::VectorVector:
|
||||
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
|
||||
if (donate_with_move) {
|
||||
out.move_shared_buffer(a);
|
||||
@@ -97,7 +97,7 @@ void set_binary_op_output_data(
|
||||
a.flags());
|
||||
}
|
||||
break;
|
||||
case General:
|
||||
case BinaryOpType::General:
|
||||
if (a.is_donatable() && a.flags().row_contiguous &&
|
||||
a.itemsize() == out.itemsize() && a.size() == out.size()) {
|
||||
if (donate_with_move) {
|
||||
@@ -424,25 +424,25 @@ void binary_op(
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
|
||||
// The full computation is scalar scalar so call the base op once
|
||||
if (bopt == ScalarScalar) {
|
||||
if (bopt == BinaryOpType::ScalarScalar) {
|
||||
*(out.data<U>()) = op(*a.data<T>(), *b.data<T>());
|
||||
return;
|
||||
}
|
||||
|
||||
// The full computation is scalar vector so delegate to the op
|
||||
if (bopt == ScalarVector) {
|
||||
if (bopt == BinaryOpType::ScalarVector) {
|
||||
opsv(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
|
||||
return;
|
||||
}
|
||||
|
||||
// The full computation is vector scalar so delegate to the op
|
||||
if (bopt == VectorScalar) {
|
||||
if (bopt == BinaryOpType::VectorScalar) {
|
||||
opvs(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
|
||||
return;
|
||||
}
|
||||
|
||||
// The full computation is vector vector so delegate to the op
|
||||
if (bopt == VectorVector) {
|
||||
if (bopt == BinaryOpType::VectorVector) {
|
||||
opvv(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
|
||||
return;
|
||||
}
|
||||
@@ -475,17 +475,17 @@ void binary_op(
|
||||
// Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
|
||||
int dim = ndim;
|
||||
if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
|
||||
bopt = VectorVector;
|
||||
bopt = BinaryOpType::VectorVector;
|
||||
dim = d;
|
||||
// Case 2: LxM and Fx1 where L and F are broadcastable and M is row
|
||||
// contiguous
|
||||
} else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
|
||||
bopt = VectorScalar;
|
||||
bopt = BinaryOpType::VectorScalar;
|
||||
dim = d;
|
||||
// Case 3: Lx1 and FxM where L and F are broadcastable and M is row
|
||||
// contiguous
|
||||
} else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
|
||||
bopt = ScalarVector;
|
||||
bopt = BinaryOpType::ScalarVector;
|
||||
dim = d;
|
||||
}
|
||||
|
||||
@@ -495,20 +495,20 @@ void binary_op(
|
||||
size_t stride;
|
||||
if (dim == 0 || strides[dim - 1] < 16) {
|
||||
stride = 1;
|
||||
bopt = General;
|
||||
bopt = BinaryOpType::General;
|
||||
dim = ndim;
|
||||
} else {
|
||||
stride = strides[dim - 1];
|
||||
}
|
||||
|
||||
switch (bopt) {
|
||||
case VectorVector:
|
||||
case BinaryOpType::VectorVector:
|
||||
binary_op_dispatch_dims<T, U>(a, b, out, opvv, dim, stride);
|
||||
break;
|
||||
case VectorScalar:
|
||||
case BinaryOpType::VectorScalar:
|
||||
binary_op_dispatch_dims<T, U>(a, b, out, opvs, dim, stride);
|
||||
break;
|
||||
case ScalarVector:
|
||||
case BinaryOpType::ScalarVector:
|
||||
binary_op_dispatch_dims<T, U>(a, b, out, opsv, dim, stride);
|
||||
break;
|
||||
default:
|
||||
|
||||
Reference in New Issue
Block a user