Implement the 'where' primitive for conditional selection (#664)

Authored by Rifur13 on 2024-02-22 18:10:48 -05:00, committed by GitHub
parent ad4a45e615
commit 126c9869c8
23 changed files with 991 additions and 56 deletions

View File

@@ -73,6 +73,7 @@ void time_unary_ops() {
void time_binary_ops() {
int M = 1000, N = 100, K = 10;
auto condition = random::randint(0, 2, {M, N, K});
auto a = random::uniform({M, N, K});
auto b = random::uniform({M, N, K});
auto device = default_device();
@@ -84,7 +85,9 @@ void time_binary_ops() {
TIME(divide, a, b, device);
TIME(maximum, a, b, device);
TIME(minimum, a, b, device);
TIME(where, condition, a, b, device);
condition = array({true});
b = random::uniform({1});
eval(b);
TIMEM("scalar", add, a, b, device);
@@ -93,7 +96,9 @@ void time_binary_ops() {
TIMEM("scalar", multiply, a, b, device);
TIMEM("vector-scalar", divide, a, b, device);
TIMEM("scalar-vector", divide, b, a, device);
TIMEM("scalar-vector", where, condition, a, b, device);
condition = broadcast_to(array({true}), {1000, 100});
a = broadcast_to(random::uniform({1}), {1000, 100});
b = broadcast_to(random::uniform({1}), {1000, 100});
eval(a, b);
@@ -101,6 +106,7 @@ void time_binary_ops() {
TIMEM("scalar-scalar broadcast", subtract, a, b, device);
TIMEM("scalar-scalar broadcast", multiply, a, b, device);
TIMEM("scalar-scalar broadcast", divide, a, b, device);
TIMEM("scalar-scalar broadcast", where, condition, a, b, device);
}
void time_strided_ops() {
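For context, a minimal standalone sketch of the op being benchmarked above, written against the public API this commit adds (assuming the umbrella header mlx/mlx.h; not part of the benchmark file):

// Hedged usage sketch: pick elements from a where condition is true,
// otherwise from b.
#include "mlx/mlx.h"
using namespace mlx::core;

int main() {
  auto condition = array({true, false, true});
  auto a = array({1.0f, 2.0f, 3.0f});
  auto b = array({10.0f, 20.0f, 30.0f});
  auto out = where(condition, a, b); // {1.0f, 20.0f, 3.0f}
  eval(out);
  return 0;
}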

View File

@@ -64,6 +64,7 @@ DEFAULT(Reshape)
DEFAULT(Remainder)
DEFAULT(Round)
DEFAULT(Scatter)
DEFAULT(Select)
DEFAULT(Sigmoid)
DEFAULT(Sign)
DEFAULT(Slice)

View File

@@ -43,6 +43,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp

View File

@@ -9,7 +9,7 @@ namespace mlx::core {
namespace {
enum BinaryOpType {
enum class BinaryOpType {
ScalarScalar,
ScalarVector,
VectorScalar,
@@ -20,17 +20,17 @@ enum BinaryOpType {
BinaryOpType get_binary_op_type(const array& a, const array& b) {
BinaryOpType bopt;
if (a.data_size() == 1 && b.data_size() == 1) {
bopt = ScalarScalar;
bopt = BinaryOpType::ScalarScalar;
} else if (a.data_size() == 1 && b.flags().contiguous) {
bopt = ScalarVector;
bopt = BinaryOpType::ScalarVector;
} else if (b.data_size() == 1 && a.flags().contiguous) {
bopt = VectorScalar;
bopt = BinaryOpType::VectorScalar;
} else if (
(a.flags().row_contiguous && b.flags().row_contiguous) ||
(a.flags().col_contiguous && b.flags().col_contiguous)) {
bopt = VectorVector;
bopt = BinaryOpType::VectorVector;
} else {
bopt = General;
bopt = BinaryOpType::General;
}
return bopt;
}
@@ -42,11 +42,11 @@ void set_binary_op_output_data(
BinaryOpType bopt,
bool donate_with_move = false) {
switch (bopt) {
case ScalarScalar:
case BinaryOpType::ScalarScalar:
out.set_data(
allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
break;
case ScalarVector:
case BinaryOpType::ScalarVector:
if (b.is_donatable() && b.itemsize() == out.itemsize()) {
if (donate_with_move) {
out.move_shared_buffer(b);
@@ -61,7 +61,7 @@ void set_binary_op_output_data(
b.flags());
}
break;
case VectorScalar:
case BinaryOpType::VectorScalar:
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
if (donate_with_move) {
out.move_shared_buffer(a);
@@ -76,7 +76,7 @@ void set_binary_op_output_data(
a.flags());
}
break;
case VectorVector:
case BinaryOpType::VectorVector:
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
if (donate_with_move) {
out.move_shared_buffer(a);
@@ -97,7 +97,7 @@ void set_binary_op_output_data(
a.flags());
}
break;
case General:
case BinaryOpType::General:
if (a.is_donatable() && a.flags().row_contiguous &&
a.itemsize() == out.itemsize() && a.size() == out.size()) {
if (donate_with_move) {
@@ -424,25 +424,25 @@ void binary_op(
set_binary_op_output_data(a, b, out, bopt);
// The full computation is scalar scalar so call the base op once
if (bopt == ScalarScalar) {
if (bopt == BinaryOpType::ScalarScalar) {
*(out.data<U>()) = op(*a.data<T>(), *b.data<T>());
return;
}
// The full computation is scalar vector so delegate to the op
if (bopt == ScalarVector) {
if (bopt == BinaryOpType::ScalarVector) {
opsv(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
return;
}
// The full computation is vector scalar so delegate to the op
if (bopt == VectorScalar) {
if (bopt == BinaryOpType::VectorScalar) {
opvs(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
return;
}
// The full computation is vector vector so delegate to the op
if (bopt == VectorVector) {
if (bopt == BinaryOpType::VectorVector) {
opvv(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
return;
}
@@ -475,17 +475,17 @@ void binary_op(
// Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
int dim = ndim;
if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
bopt = VectorVector;
bopt = BinaryOpType::VectorVector;
dim = d;
// Case 2: LxM and Fx1 where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
bopt = VectorScalar;
bopt = BinaryOpType::VectorScalar;
dim = d;
// Case 3: Lx1 and FxM where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
bopt = ScalarVector;
bopt = BinaryOpType::ScalarVector;
dim = d;
}
@@ -495,20 +495,20 @@ void binary_op(
size_t stride;
if (dim == 0 || strides[dim - 1] < 16) {
stride = 1;
bopt = General;
bopt = BinaryOpType::General;
dim = ndim;
} else {
stride = strides[dim - 1];
}
switch (bopt) {
case VectorVector:
case BinaryOpType::VectorVector:
binary_op_dispatch_dims<T, U>(a, b, out, opvv, dim, stride);
break;
case VectorScalar:
case BinaryOpType::VectorScalar:
binary_op_dispatch_dims<T, U>(a, b, out, opvs, dim, stride);
break;
case ScalarVector:
case BinaryOpType::ScalarVector:
binary_op_dispatch_dims<T, U>(a, b, out, opsv, dim, stride);
break;
default:
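The only substantive change in this header is the switch to a scoped enum: with enum class, the enumerators stop leaking into the enclosing namespace, which is why every use below gains the BinaryOpType:: qualifier. A minimal illustration:

enum class BinaryOpType { ScalarScalar, General };
// BinaryOpType t = ScalarScalar;             // error: name not in scope
BinaryOpType t = BinaryOpType::ScalarScalar;  // OK: qualified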

View File

@@ -260,14 +260,14 @@ void binary_op(
set_binary_op_output_data(a, b, out_b, bopt);
// The full computation is scalar scalar so call the base op once
if (bopt == ScalarScalar) {
if (bopt == BinaryOpType::ScalarScalar) {
std::tie(*(out_a.data<U>()), *(out_b.data<U>())) =
op(*a.data<T>(), *b.data<T>());
return;
}
// The full computation is scalar vector so delegate to the op
if (bopt == ScalarVector) {
if (bopt == BinaryOpType::ScalarVector) {
opsv(
a.data<T>(),
b.data<T>(),
@@ -278,7 +278,7 @@ void binary_op(
}
// The full computation is vector scalar so delegate to the op
if (bopt == VectorScalar) {
if (bopt == BinaryOpType::VectorScalar) {
opvs(
a.data<T>(),
b.data<T>(),
@@ -289,7 +289,7 @@ void binary_op(
}
// The full computation is vector vector so delegate to the op
if (bopt == VectorVector) {
if (bopt == BinaryOpType::VectorVector) {
opvv(
a.data<T>(),
b.data<T>(),
@@ -327,17 +327,17 @@ void binary_op(
// Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
int dim = ndim;
if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
bopt = VectorVector;
bopt = BinaryOpType::VectorVector;
dim = d;
// Case 2: LxM and Fx1 where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
bopt = VectorScalar;
bopt = BinaryOpType::VectorScalar;
dim = d;
// Case 3: Lx1 and FxM where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
bopt = ScalarVector;
bopt = BinaryOpType::ScalarVector;
dim = d;
}
@@ -347,20 +347,20 @@ void binary_op(
size_t stride;
if (dim == 0 || strides[dim - 1] < 16) {
stride = 1;
bopt = General;
bopt = BinaryOpType::General;
dim = ndim;
} else {
stride = strides[dim - 1];
}
switch (bopt) {
case VectorVector:
case BinaryOpType::VectorVector:
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvv, dim, stride);
break;
case VectorScalar:
case BinaryOpType::VectorScalar:
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvs, dim, stride);
break;
case ScalarVector:
case BinaryOpType::ScalarVector:
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opsv, dim, stride);
break;
default:

View File

@@ -87,6 +87,7 @@ DEFAULT(Reshape)
DEFAULT(Round)
DEFAULT(Scan)
DEFAULT(Scatter)
DEFAULT(Select)
DEFAULT(Sigmoid)
DEFAULT(Sign)
DEFAULT(Sin)

View File

@@ -588,4 +588,11 @@ struct LogicalOr {
};
};
struct Select {
template <typename T>
T operator()(bool condition, T x, T y) {
return condition ? x : y;
}
};
} // namespace mlx::core::detail
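Behaviorally the functor is just the ternary conditional; a tiny standalone check with hypothetical values:

detail::Select select;
float kept = select(true, 1.0f, 2.0f);     // 1.0f: condition picks x
float dropped = select(false, 1.0f, 2.0f); // 2.0f: condition picks y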

View File

@@ -0,0 +1,72 @@
// Copyright © 2023 Apple Inc.
#include <cassert>
#include "mlx/backend/common/ternary.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
template <typename Op>
void select_op(
const array& a,
const array& b,
const array& c,
array& out,
Op op) {
switch (out.dtype()) {
case bool_:
ternary_op<bool, bool, bool, bool>(a, b, c, out, op);
break;
case uint8:
ternary_op<bool, uint8_t, uint8_t, uint8_t>(a, b, c, out, op);
break;
case uint16:
ternary_op<bool, uint16_t, uint16_t, uint16_t>(a, b, c, out, op);
break;
case uint32:
ternary_op<bool, uint32_t, uint32_t, uint32_t>(a, b, c, out, op);
break;
case uint64:
ternary_op<bool, uint64_t, uint64_t, uint64_t>(a, b, c, out, op);
break;
case int8:
ternary_op<bool, int8_t, int8_t, int8_t>(a, b, c, out, op);
break;
case int16:
ternary_op<bool, int16_t, int16_t, int16_t>(a, b, c, out, op);
break;
case int32:
ternary_op<bool, int32_t, int32_t, int32_t>(a, b, c, out, op);
break;
case int64:
ternary_op<bool, int64_t, int64_t, int64_t>(a, b, c, out, op);
break;
case float16:
ternary_op<bool, float16_t, float16_t, float16_t>(a, b, c, out, op);
break;
case float32:
ternary_op<bool, float, float, float>(a, b, c, out, op);
break;
case bfloat16:
ternary_op<bool, bfloat16_t, bfloat16_t, bfloat16_t>(a, b, c, out, op);
break;
case complex64:
ternary_op<bool, complex64_t, complex64_t, complex64_t>(a, b, c, out, op);
break;
}
}
} // namespace
void Select::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 3);
const auto& condition = inputs[0];
const auto& a = inputs[1];
const auto& b = inputs[2];
select_op(condition, a, b, out, detail::Select());
}
} // namespace mlx::core
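Note that every branch of the switch pins the first template parameter to bool (the condition's storage type) while the value types follow the output dtype, which where() has already promoted. For a float32 output, for example:

// The float32 case resolves to
//   ternary_op<bool, float, float, float>(condition, a, b, out, op);
// so Op is invoked as op(bool, float, float) -> float, matching
// detail::Select::operator()(bool, T, T).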

View File

@@ -0,0 +1,226 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/ops.h"
#include "mlx/backend/common/utils.h"
namespace mlx::core {
namespace {
// TODO: Add support for more combinations of input types.
enum class TernaryOpType {
ScalarScalarScalar,
General,
};
TernaryOpType
get_ternary_op_type(const array& a, const array& b, const array& c) {
TernaryOpType topt;
if (a.data_size() == 1 && b.data_size() == 1 && c.data_size() == 1) {
topt = TernaryOpType::ScalarScalarScalar;
} else {
topt = TernaryOpType::General;
}
return topt;
}
void set_ternary_op_output_data(
const array& a,
const array& b,
const array& c,
array& out,
TernaryOpType topt,
bool donate_with_move = false) {
switch (topt) {
case TernaryOpType::ScalarScalarScalar:
out.set_data(
allocator::malloc_or_wait(out.itemsize()), 1, b.strides(), b.flags());
break;
case TernaryOpType::General:
out.set_data(allocator::malloc_or_wait(out.nbytes()));
break;
}
}
template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op_dims1(
const array& a,
const array& b,
const array& c,
array& out,
Op op) {
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t c_idx = 0;
for (size_t i = 0; i < out.size(); ++i) {
dst[i] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
a_idx += a.strides()[0];
b_idx += b.strides()[0];
c_idx += c.strides()[0];
}
}
template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op_dims2(
const array& a,
const array& b,
const array& c,
array& out,
Op op) {
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t c_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
a_idx += a.strides()[1];
b_idx += b.strides()[1];
c_idx += c.strides()[1];
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
}
}
template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op_dims3(
const array& a,
const array& b,
const array& c,
array& out,
Op op) {
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t c_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
for (size_t k = 0; k < a.shape()[2]; ++k) {
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
a_idx += a.strides()[2];
b_idx += b.strides()[2];
c_idx += c.strides()[2];
}
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
c_idx += c.strides()[1] - c.strides()[2] * c.shape()[2];
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
}
}
template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op_dims4(
const array& a,
const array& b,
const array& c,
array& out,
Op op) {
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t c_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
for (size_t k = 0; k < a.shape()[2]; ++k) {
for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
a_idx += a.strides()[3];
b_idx += b.strides()[3];
c_idx += c.strides()[3];
}
a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
c_idx += c.strides()[2] - c.strides()[3] * c.shape()[3];
}
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
c_idx += c.strides()[1] - c.strides()[2] * c.shape()[2];
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
}
}
template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op_dispatch_dims(
const array& a,
const array& b,
const array& c,
array& out,
Op op) {
switch (out.ndim()) {
case 1:
ternary_op_dims1<T1, T2, T3, U, Op>(a, b, c, out, op);
return;
case 2:
ternary_op_dims2<T1, T2, T3, U, Op>(a, b, c, out, op);
return;
case 3:
ternary_op_dims3<T1, T2, T3, U, Op>(a, b, c, out, op);
return;
case 4:
ternary_op_dims4<T1, T2, T3, U, Op>(a, b, c, out, op);
return;
}
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* dst = out.data<U>();
for (size_t i = 0; i < out.size(); i++) {
int a_idx = elem_to_loc(i, a.shape(), a.strides());
int b_idx = elem_to_loc(i, b.shape(), b.strides());
int c_idx = elem_to_loc(i, c.shape(), c.strides());
dst[i] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
}
}
template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op(
const array& a,
const array& b,
const array& c,
array& out,
Op op) {
TernaryOpType topt = get_ternary_op_type(a, b, c);
set_ternary_op_output_data(a, b, c, out, topt);
// The full computation is scalar-scalar-scalar so we call the base op once.
if (topt == TernaryOpType::ScalarScalarScalar) {
*(out.data<U>()) = op(*a.data<T1>(), *b.data<T2>(), *c.data<T3>());
return;
}
ternary_op_dispatch_dims<T1, T2, T3, U>(a, b, c, out, op);
}
} // namespace
} // namespace mlx::core
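The ternary_op_dims1 through dims4 specializations avoid a per-element elem_to_loc call by carrying running offsets and adding strides, with a stride of 0 expressing a broadcast input. A minimal one-dimensional sketch of the idea (standalone, hypothetical data):

#include <cstddef>

// A stride of 0 means the offset never advances, so the same element
// is re-read at every output position, which is broadcasting.
void stride_walk_example() {
  bool cond[1] = {true};             // broadcast scalar: stride 0
  float x[4] = {1.f, 2.f, 3.f, 4.f}; // contiguous: stride 1
  float y[1] = {9.f};                // broadcast scalar: stride 0
  float out[4];
  size_t c_stride = 0, x_stride = 1, y_stride = 0;
  size_t c_idx = 0, x_idx = 0, y_idx = 0;
  for (size_t i = 0; i < 4; ++i) {
    out[i] = cond[c_idx] ? x[x_idx] : y[y_idx];
    c_idx += c_stride;
    x_idx += x_stride;
    y_idx += y_stride;
  }
}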

View File

@@ -27,6 +27,7 @@ set(
"scan"
"softmax"
"sort"
"ternary"
"unary"
"gather"
"scatter"

View File

@@ -1,6 +1,7 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/kernels/binary.h"
#include "mlx/backend/metal/kernels/ternary.h"
#include "mlx/backend/metal/kernels/unary.h"
typedef half float16_t;

View File

@@ -0,0 +1,10 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
struct Select {
template <typename T>
T operator()(bool condition, T x, T y) {
return condition ? x : y;
}
};

View File

@@ -0,0 +1,184 @@
// Copyright © 2023 Apple Inc.
#include <metal_integer>
#include <metal_math>
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/ternary.h"
template <typename T, typename Op>
[[kernel]] void ternary_op_g_nd1(
device const bool* a,
device const T* b,
device const T* c,
device T* d,
constant const size_t& a_strides,
constant const size_t& b_strides,
constant const size_t& c_strides,
uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1(index, a_strides);
auto b_idx = elem_to_loc_1(index, b_strides);
auto c_idx = elem_to_loc_1(index, c_strides);
d[index] = Op()(a[a_idx], b[b_idx], c[c_idx]);
}
template <typename T, typename Op>
[[kernel]] void ternary_op_g_nd2(
device const bool* a,
device const T* b,
device const T* c,
device T* d,
constant const size_t a_strides[2],
constant const size_t b_strides[2],
constant const size_t c_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides);
auto c_idx = elem_to_loc_2(index, c_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
}
template <typename T, typename Op>
[[kernel]] void ternary_op_g_nd3(
device const bool* a,
device const T* b,
device const T* c,
device T* d,
constant const size_t a_strides[3],
constant const size_t b_strides[3],
constant const size_t c_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
auto c_idx = elem_to_loc_3(index, c_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
}
template <typename T, typename Op, int DIM>
[[kernel]] void ternary_op_g_nd(
device const bool* a,
device const T* b,
device const T* c,
device T* d,
constant const int shape[DIM],
constant const size_t a_strides[DIM],
constant const size_t b_strides[DIM],
constant const size_t c_strides[DIM],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_3_nd<DIM>(index, shape, a_strides, b_strides, c_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
}
template <typename T, typename Op>
[[kernel]] void ternary_op_g(
device const bool* a,
device const T* b,
device const T* c,
device T* d,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const size_t* c_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_3_nd(index, shape, a_strides, b_strides, c_strides, ndim);
size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
}
#define instantiate_ternary_g(name, type, op) \
template [[host_name(name)]] \
[[kernel]] void ternary_op_g<type, op>( \
device const bool* a, \
device const type* b, \
device const type* c, \
device type* d, \
constant const int* shape, \
constant const size_t* a_strides, \
constant const size_t* b_strides, \
constant const size_t* c_strides, \
constant const int& ndim, \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
#define instantiate_ternary_g_dim(name, type, op, dims) \
template [[host_name(name "_" #dims)]] \
[[kernel]] void ternary_op_g_nd<type, op, dims>( \
device const bool* a, \
device const type* b, \
device const type* c, \
device type* d, \
constant const int shape[dims], \
constant const size_t a_strides[dims], \
constant const size_t b_strides[dims], \
constant const size_t c_strides[dims], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
#define instantiate_ternary_g_nd(name, type, op) \
template [[host_name(name "_1")]] \
[[kernel]] void ternary_op_g_nd1<type, op>( \
device const bool* a, \
device const type* b, \
device const type* c, \
device type* d, \
constant const size_t& a_strides, \
constant const size_t& b_strides, \
constant const size_t& c_strides, \
uint index [[thread_position_in_grid]]); \
template [[host_name(name "_2")]] \
[[kernel]] void ternary_op_g_nd2<type, op>( \
device const bool* a, \
device const type* b, \
device const type* c, \
device type* d, \
constant const size_t a_strides[2], \
constant const size_t b_strides[2], \
constant const size_t c_strides[2], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name(name "_3")]] \
[[kernel]] void ternary_op_g_nd3<type, op>( \
device const bool* a, \
device const type* b, \
device const type* c, \
device type* d, \
constant const size_t a_strides[3], \
constant const size_t b_strides[3], \
constant const size_t c_strides[3], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
instantiate_ternary_g_dim(name, type, op, 4) \
instantiate_ternary_g_dim(name, type, op, 5) \
#define instantiate_ternary_all(name, tname, type, op) \
instantiate_ternary_g("g" #name #tname, type, op) \
instantiate_ternary_g_nd("g" #name #tname, type, op) \
#define instantiate_ternary_float(name, op) \
instantiate_ternary_all(name, float16, half, op) \
instantiate_ternary_all(name, float32, float, op) \
instantiate_ternary_all(name, bfloat16, bfloat16_t, op)
#define instantiate_ternary_types(name, op) \
instantiate_ternary_all(name, bool_, bool, op) \
instantiate_ternary_all(name, uint8, uint8_t, op) \
instantiate_ternary_all(name, uint16, uint16_t, op) \
instantiate_ternary_all(name, uint32, uint32_t, op) \
instantiate_ternary_all(name, uint64, uint64_t, op) \
instantiate_ternary_all(name, int8, int8_t, op) \
instantiate_ternary_all(name, int16, int16_t, op) \
instantiate_ternary_all(name, int32, int32_t, op) \
instantiate_ternary_all(name, int64, int64_t, op) \
instantiate_ternary_all(name, complex64, complex64_t, op) \
instantiate_ternary_float(name, op)
instantiate_ternary_types(select, Select)
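The host side (see metal/primitives.cpp below) requests kernels by the same names these macros register; roughly:

// instantiate_ternary_all(select, float32, float, Select) registers
//   "gselectfloat32"              (general kernel, runtime ndim)
//   "gselectfloat32_1" ... "_3"   (fixed 1-3 dim kernels)
//   "gselectfloat32_4", "_5"      (via instantiate_ternary_g_dim)
// matching the host-side kname built as "g" << op << type_to_name(b),
// with an optional "_" << ndim suffix for small ranks.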

View File

@@ -91,6 +91,30 @@ inline size_t elem_to_loc(
return loc;
}
template <int NDIM>
inline uint3 elem_to_loc_3_nd(
uint3 elem,
constant const int shape[NDIM],
constant const size_t a_strides[NDIM],
constant const size_t b_strides[NDIM],
constant const size_t c_strides[NDIM]) {
uint3 loc = {
static_cast<uint>(
elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
static_cast<uint>(
elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2]),
static_cast<uint>(
elem.x * c_strides[NDIM - 1] + elem.y * c_strides[NDIM - 2])};
for (int d = NDIM - 3; d >= 0; --d) {
uint l = elem.z % shape[d];
loc.x += l * a_strides[d];
loc.y += l * b_strides[d];
loc.z += l * c_strides[d];
elem.z /= shape[d];
}
return loc;
}
template <int NDIM>
inline uint2 elem_to_loc_2_nd(
uint3 elem,
@@ -150,6 +174,30 @@ inline size_t elem_to_loc(
return loc;
}
inline uint3 elem_to_loc_3_nd(
uint3 elem,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const size_t* c_strides,
int ndim) {
uint3 loc = {
static_cast<uint>(
elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
static_cast<uint>(
elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2]),
static_cast<uint>(
elem.x * c_strides[ndim - 1] + elem.y * c_strides[ndim - 2])};
for (int d = ndim - 3; d >= 0; --d) {
uint l = elem.z % shape[d];
loc.x += l * a_strides[d];
loc.y += l * b_strides[d];
loc.z += l * c_strides[d];
elem.z /= shape[d];
}
return loc;
}
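A worked example of the index math for one input (hypothetical 4-D case; it applies to both the templated and runtime-ndim variants):

// ndim = 4, shape = {S0, S1, S2, S3}; the launch grid maps elem.x over
// S3 (innermost), elem.y over S2, and elem.z over the remaining S0*S1.
// For a stride array st[4] the loop accumulates
//   loc = elem.x*st[3] + elem.y*st[2]
//       + (elem.z % S1)*st[1] + (elem.z / S1)*st[0]
// independently for each of the three inputs' stride arrays.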
inline uint2 elem_to_loc_2_nd(
uint3 elem,
constant const int* shape,

View File

@@ -6,6 +6,7 @@
#include <sstream>
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/ternary.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
@@ -43,24 +44,25 @@ void binary_op(
std::ostringstream kname;
switch (bopt) {
case ScalarScalar:
case BinaryOpType::ScalarScalar:
kname << "ss";
break;
case ScalarVector:
case BinaryOpType::ScalarVector:
kname << "sv";
break;
case VectorScalar:
case BinaryOpType::VectorScalar:
kname << "vs";
break;
case VectorVector:
case BinaryOpType::VectorVector:
kname << "vv";
break;
case General:
case BinaryOpType::General:
kname << "g";
break;
}
kname << op << type_to_name(a);
if (bopt == General && shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
if (bopt == BinaryOpType::General &&
shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
kname << "_" << shape.size();
}
@@ -80,7 +82,7 @@ void binary_op(
set_array_buffer(compute_encoder, outputs[0], 2);
set_array_buffer(compute_encoder, outputs[1], 3);
if (bopt == General) {
if (bopt == BinaryOpType::General) {
auto ndim = shape.size();
if (ndim > 3) {
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
@@ -141,24 +143,25 @@ void binary_op(
std::ostringstream kname;
switch (bopt) {
case ScalarScalar:
case BinaryOpType::ScalarScalar:
kname << "ss";
break;
case ScalarVector:
case BinaryOpType::ScalarVector:
kname << "sv";
break;
case VectorScalar:
case BinaryOpType::VectorScalar:
kname << "vs";
break;
case VectorVector:
case BinaryOpType::VectorVector:
kname << "vv";
break;
case General:
case BinaryOpType::General:
kname << "g";
break;
}
kname << op << type_to_name(a);
if (bopt == General && shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
if (bopt == BinaryOpType::General &&
shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
kname << "_" << shape.size();
}
@@ -173,7 +176,7 @@ void binary_op(
set_array_buffer(compute_encoder, donate_b ? out : b, 1);
set_array_buffer(compute_encoder, out, 2);
if (bopt == General) {
if (bopt == BinaryOpType::General) {
auto ndim = shape.size();
if (ndim > 3) {
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 3);
@@ -202,7 +205,8 @@ void binary_op(
compute_encoder->dispatchThreads(grid_dims, group_dims);
} else {
// Launch a 1D grid of threads
size_t nthreads = bopt == General ? out.size() : out.data_size();
size_t nthreads =
bopt == BinaryOpType::General ? out.size() : out.data_size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
@@ -213,6 +217,86 @@ void binary_op(
}
}
void ternary_op(
const std::vector<array>& inputs,
array& out,
const std::string op) {
assert(inputs.size() == 3);
auto& a = inputs[0];
auto& b = inputs[1];
auto& c = inputs[2];
TernaryOpType topt = get_ternary_op_type(a, b, c);
set_ternary_op_output_data(a, b, c, out, topt, true /* donate_with_move */);
if (out.size() == 0) {
return;
}
// Try to collapse contiguous dims
auto [shape, strides] = collapse_contiguous_dims(a, b, c, out);
auto& strides_a = strides[0];
auto& strides_b = strides[1];
auto& strides_c = strides[2];
auto& strides_out = strides[3];
std::ostringstream kname;
kname << "g";
kname << op << type_to_name(b);
if (topt == TernaryOpType::General &&
shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
kname << "_" << shape.size();
}
auto& s = out.primitive().stream();
auto& d = metal::device(s.device);
auto kernel = d.get_kernel(kname.str());
auto compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, a, 0);
set_array_buffer(compute_encoder, b, 1);
set_array_buffer(compute_encoder, c, 2);
set_array_buffer(compute_encoder, out, 3);
auto ndim = shape.size();
if (ndim > 3) {
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6);
compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 7);
if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
compute_encoder->setBytes(&ndim, sizeof(int), 8);
}
} else if (ndim > 0) {
// The shape is implicit in the grid for <= 3D
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 6);
} else {
// For 0-dim we still need to bind something to these buffers since the
// current ternary kernels always access the strides.
size_t dummy_stride = 0;
int dummy_shape = 0;
compute_encoder->setBytes(&dummy_shape, sizeof(int), 4);
compute_encoder->setBytes(&dummy_stride, sizeof(size_t), 5);
compute_encoder->setBytes(&dummy_stride, sizeof(size_t), 6);
compute_encoder->setBytes(&dummy_stride, sizeof(size_t), 7);
compute_encoder->setBytes(&ndim, sizeof(int), 8);
}
// Launch up to 3D grid of threads
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t rest = out.size() / (dim0 * dim1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
}
MTL::Size group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
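The 3-D launch pairs with the kernels' output indexing: the grid is sized to the collapsed output shape, so a contiguous output is written in order:

// Grid: (dim0, dim1, rest), dim0 being the innermost extent.
// Kernel side: out_idx = x + grid.x * (y + grid.y * z),
// i.e. row-major over (z, y, x), exactly a contiguous out buffer.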
void unary_op(
const std::vector<array>& inputs,
array& out,
@@ -619,6 +703,10 @@ void Multiply::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "mul");
}
void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
ternary_op(inputs, out, "select");
}
void Negative::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "neg");
}

View File

@@ -80,6 +80,7 @@ NO_GPU(Reshape)
NO_GPU(Round)
NO_GPU(Scan)
NO_GPU(Scatter)
NO_GPU(Select)
NO_GPU(Sigmoid)
NO_GPU(Sign)
NO_GPU(Sin)

View File

@@ -47,6 +47,10 @@ bool is_binary(const Primitive& p) {
typeid(p) == typeid(Subtract));
}
bool is_ternary(const Primitive& p) {
return typeid(p) == typeid(Select);
}
bool is_broadcast(const Primitive& p) {
return typeid(p) == typeid(Broadcast);
}
@@ -60,14 +64,16 @@ bool is_reduction(const Primitive& p) {
}
bool is_fusable(const Primitive& p) {
return is_unary(p) || is_binary(p) || is_broadcast(p) || is_noop(p);
return is_unary(p) || is_binary(p) || is_ternary(p) || is_broadcast(p) ||
is_noop(p);
}
bool allows_shapeless(const Primitive& p) {
return typeid(p) == typeid(Compiled) || is_unary(p) || is_binary(p) ||
is_noop(p) || is_reduction(p) || typeid(p) == typeid(Softmax) ||
typeid(p) == typeid(Sort) || typeid(p) == typeid(ArgSort) ||
typeid(p) == typeid(ArgPartition) || typeid(p) == typeid(Partition);
typeid(p) == typeid(ArgPartition) || typeid(p) == typeid(Partition) ||
typeid(p) == typeid(Select);
}
Compiled::Compiled(

View File

@@ -1149,13 +1149,20 @@ array isneginf(const array& a, StreamOrDevice s /* = {} */) {
}
array where(
const array& condition,
const array& x,
const array& y,
const array& a,
const array& b,
const array& c,
StreamOrDevice s /* = {} */) {
// TODO, fix this to handle the NaN case when x has infs
auto mask = astype(condition, bool_, s);
return add(multiply(x, mask, s), multiply(y, logical_not(mask, s), s), s);
auto condition = astype(a, bool_, s);
Dtype out_dtype = promote_types(b.dtype(), c.dtype());
auto inputs = broadcast_arrays(
{condition, astype(b, out_dtype, s), astype(c, out_dtype, s)}, s);
return array(
inputs[0].shape(),
out_dtype,
std::make_unique<Select>(to_stream(s)),
inputs);
}
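With the primitive in place, where() promotes and broadcasts up front instead of lowering to masked arithmetic; a small sketch of the resulting semantics (hypothetical values):

// condition is cast to bool, b and c are promoted to a common dtype,
// and all three are broadcast to one shape before Select runs.
auto cond = array({1, 0});     // any dtype works: cast via astype(bool_)
auto b = array({1, 2});        // int32
auto c = array({1.5f, 2.5f});  // float32
auto out = where(cond, b, c);  // float32, values {1.0f, 2.5f}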
array allclose(

View File

@@ -48,6 +48,54 @@ std::tuple<array, array, int> vmap_binary_op(
return {a, b, to_ax};
}
std::tuple<array, array, array, int> vmap_ternary_op(
const std::vector<array>& inputs,
const std::vector<int>& axes,
const Stream& stream) {
assert(inputs.size() == 3);
assert(axes.size() == 3);
auto a = inputs[0];
auto b = inputs[1];
auto c = inputs[2];
int ndim = std::max(
{a.ndim() + (axes[0] == -1),
b.ndim() + (axes[1] == -1),
c.ndim() + (axes[2] == -1)});
auto expand_dims = [stream, ndim](auto in) {
auto shape = in.shape();
shape.insert(shape.begin(), ndim - shape.size(), 1);
return reshape(in, shape, stream);
};
int to_ax = (ndim - a.ndim()) + axes[0];
int from_ax1 = (ndim - b.ndim()) + axes[1];
int from_ax2 = (ndim - c.ndim()) + axes[2];
a = expand_dims(a);
b = expand_dims(b);
c = expand_dims(c);
auto find_tdims = [](auto x, int to_ax, int from_ax) {
std::vector<int> tdims(x.ndim());
std::iota(tdims.begin(), tdims.end(), 0);
tdims.erase(tdims.begin() + from_ax);
tdims.insert(tdims.begin() + to_ax, from_ax);
return tdims;
};
if (to_ax != from_ax1) {
std::vector<int> tdims = find_tdims(b, to_ax, from_ax1);
b = transpose(b, tdims, stream);
}
if (to_ax != from_ax2) {
std::vector<int> tdims = find_tdims(c, to_ax, from_ax2);
c = transpose(c, tdims, stream);
}
return {a, b, c, to_ax};
}
} // namespace
std::vector<array> Primitive::jvp(
@@ -1775,6 +1823,76 @@ std::pair<std::vector<array>, std::vector<int>> Multiply::vmap(
return {{multiply(a, b, stream())}, {to_ax}};
}
std::vector<array> Select::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
assert(primals.size() == 3);
assert(tangents.size() == 3);
// Take the primal index (an entry of argnums) directly, so the call
// sites jvp_fun(argnums[i]) pick the intended primal and tangent.
auto jvp_fun = [&](int arg) {
if (arg == 0) {
return zeros_like(primals[0], stream());
} else if (arg == 1) {
return multiply(
astype(primals[0], tangents[1].dtype(), stream()),
tangents[1],
stream());
} else {
return multiply(
astype(
logical_not(primals[0], stream()), tangents[2].dtype(), stream()),
tangents[2],
stream());
}
};
array jvp = jvp_fun(argnums[0]);
for (int i = 1; i < argnums.size(); i++) {
jvp = add(jvp, jvp_fun(argnums[i]));
}
return {jvp};
}
std::vector<array> Select::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
assert(primals.size() == 3);
assert(cotangents.size() == 1);
std::vector<array> vjps;
for (auto arg : argnums) {
if (arg == 0) {
vjps.push_back(zeros_like(primals[0], stream()));
} else if (arg == 1) {
vjps.push_back(multiply(
astype(primals[0], cotangents[0].dtype(), stream()),
cotangents[0],
stream()));
} else if (arg == 2) {
vjps.push_back(multiply(
astype(
logical_not(primals[0], stream()),
cotangents[0].dtype(),
stream()),
cotangents[0],
stream()));
}
}
return vjps;
}
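Both directions implement the same elementwise linearization of out = where(c, x, y):

// out = c * x + (1 - c) * y   (elementwise, with c in {0, 1})
// dout/dx = c      -> multiply(astype(c, dtype), tangent or cotangent)
// dout/dy = 1 - c  -> multiply(astype(logical_not(c), dtype), ...)
// dout/dc = 0      -> zeros_like(c); the condition is not differentiable.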
std::pair<std::vector<array>, std::vector<int>> Select::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto [a, b, c, to_ax] = vmap_ternary_op(inputs, axes, stream());
return {{where(a, b, c, stream())}, {to_ax}};
}
std::vector<array> Negative::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,

View File

@@ -719,6 +719,23 @@ class DivMod : public Primitive {
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
};
class Select : public UnaryPrimitive {
public:
explicit Select(Stream stream) : UnaryPrimitive(stream){};
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(Select)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class Remainder : public UnaryPrimitive {
public:
explicit Remainder(Stream stream) : UnaryPrimitive(stream){};

View File

@@ -1075,6 +1075,37 @@ TEST_CASE("test jvp from vjp") {
CHECK(compute_derivs(subtract));
CHECK(compute_derivs(power));
}
// Conditional selection element-wise op
{
auto condition = random::randint(0, 2, {5, 10});
auto x = random::uniform({5, 10});
auto y = random::uniform({5, 10});
eval(condition, x, y);
auto compute_derivs = [&condition, &x, &y](auto fn) {
auto fn_wrap = [&fn](std::vector<array> inputs) {
return std::vector<array>{
fn(inputs[0], inputs[1], inputs[2], default_device())};
};
// Compute vjp and add results
auto vjps = vjp(fn_wrap, {condition, x, y}, {ones(x.shape())}).second;
auto vjp_out = add(add(vjps[0], vjps[1]), vjps[2]);
// Compute jvp
array jvp_out =
jvp(fn_wrap,
{condition, x, y},
{ones(condition.shape()), ones(y.shape()), ones(x.shape())})
.second[0];
array result = array_equal(vjp_out, jvp_out);
return result.item<bool>();
};
CHECK(compute_derivs(where));
}
}
TEST_CASE("test complex gradients") {

View File

@@ -2193,6 +2193,8 @@ TEST_CASE("test power") {
}
TEST_CASE("test where") {
const float inf = std::numeric_limits<float>::infinity();
array condition(true);
array x(1.0f);
array y(0.0f);
@@ -2224,6 +2226,49 @@ TEST_CASE("test where") {
out = where(condition, x, y);
expected = array({1, 2, 2, 1}, {2, 2});
CHECK(array_equal(where(condition, x, y), expected).item<bool>());
condition = array(true);
x = array({1, 2, 3});
y = array({3, 6, 13});
CHECK(array_equal(where(condition, x, y), array({1, 2, 3})).item<bool>());
condition = array(false);
x = array({1, 2, 3});
y = array({3, 6, 13});
CHECK(array_equal(where(condition, x, y), array({3, 6, 13})).item<bool>());
condition = array({1, 1, 0});
x = array({1, 2, 3});
y = array({11, 12, 13});
CHECK(array_equal(where(condition, x, y), array({1, 2, 13})).item<bool>());
condition = array({true, false}, {2, 1, 1});
x = array({1, 2, 3, 4}, {2, 1, 2});
y = array({11, 22, 33, 44}, {2, 2, 1});
expected = array({1, 2, 1, 2, 33, 33, 44, 44}, {2, 2, 2});
CHECK(array_equal(where(condition, x, y), expected).item<bool>());
condition = array({true, false, false});
x = array({inf, 2.0, 3.0});
y = array({10.0, 20.0, -inf});
CHECK(array_equal(where(condition, x, y), array({inf, 20.0, -inf}))
.item<bool>());
// 4-dim optimized case.
condition = array({false});
x = array({1, 2}, {2, 1, 1, 1});
y = array({3, 4}, {1, 1, 2, 1});
CHECK(array_equal(where(condition, x, y), array({3, 4, 3, 4}, {2, 1, 2, 1}))
.item<bool>());
// 5-dim optimized case.
condition = array({true, false}, {2, 1, 1, 1, 1});
x = array({1, 2, 3, 4}, {2, 1, 1, 1, 2});
y = array({11, 22}, {1, 1, 2, 1, 1});
CHECK(array_equal(
where(condition, x, y),
array({1, 2, 1, 2, 11, 11, 22, 22}, {2, 1, 2, 1, 2}))
.item<bool>());
}
TEST_CASE("test stack") {

View File

@@ -138,6 +138,70 @@ TEST_CASE("test simple vmap") {
CHECK(array_equal(out, x + y).item<bool>());
}
// vmap where (ternary op)
{
auto fun = [](std::vector<array> inputs) {
auto out = where(inputs[0], inputs[1], inputs[2]);
return std::vector<array>{out};
};
auto vfun = vmap(fun);
array cond({true, false}, {2, 1});
array x({1.0, 2.0}, {2, 1});
array y({2.0, 4.0}, {2, 1});
auto out = vfun({cond, x, y})[0];
CHECK(array_equal(out, array({1.0, 4.0}, {2, 1})).item<bool>());
cond = array({true, true, false}, {1, 3});
x = ones({2, 1, 3});
y = zeros({3, 2});
vfun = vmap(fun, {1, 2, 0});
out = vfun({cond, x, y})[0];
CHECK(
array_equal(out, array({1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0}, {3, 2, 2}))
.item<bool>());
vfun = vmap(fun, {1, 2, 0}, {1});
out = vfun({cond, x, y})[0];
CHECK(
array_equal(out, array({1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0}, {2, 3, 2}))
.item<bool>());
cond = array({true, false});
x = array(2.);
y = ones({3, 2});
vfun = vmap(fun, {-1, -1, 0});
out = vfun({cond, x, y})[0];
CHECK(array_equal(out, array({2, 1, 2, 1, 2, 1}, {3, 2})).item<bool>());
cond = array({true, false});
x = ones({3, 2});
y = array(2.);
vfun = vmap(fun, {-1, 0, -1});
out = vfun({cond, x, y})[0];
CHECK(array_equal(out, array({1, 2, 1, 2, 1, 2}, {3, 2})).item<bool>());
CHECK_THROWS_AS(vmap(fun, {-1, -1, -1}, {0}), std::invalid_argument);
CHECK_THROWS_AS(vmap(fun, {-1, 0, -1}, {-1}), std::invalid_argument);
CHECK_THROWS_AS(vmap(fun, {-1, -1, 0}, {-1}), std::invalid_argument);
CHECK_THROWS_AS(vmap(fun, {0, -1, -1}, {-1}), std::invalid_argument);
cond = array({true, false});
x = array(1.);
y = array(2.);
vfun = vmap(fun, {-1, -1, -1}, {-1});
out = vfun({cond, x, y})[0];
CHECK(array_equal(out, array({1.0, 2.0})).item<bool>());
cond = array({1, 1, 1, 0, 0, 0}, {3, 2, 1});
x = ones({3, 2, 1});
y = full({3, 2, 1}, 2);
vfun = vmap(vmap(fun));
out = vfun({cond, x, y})[0];
CHECK(array_equal(out, array({1, 1, 1, 2, 2, 2}, {3, 2, 1})).item<bool>());
}
// vmap with capturing closure
{
auto x = add(add(ones({2}), zeros({2})), zeros({2}));