Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-24 09:21:16 +08:00)
Implement the 'where' primitive for conditional selection (#664)
This commit is contained in:
parent ad4a45e615
commit 126c9869c8
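
This commit replaces the arithmetic emulation of `where` with a dedicated Select primitive, wired through the CPU and Metal backends, autograd, vmap, and the compiler's fusion rules. A minimal sketch of the op in use, with values taken from the C++ tests added below:

// where(condition, x, y) picks from x where condition is true, else from y.
array condition = array({1, 1, 0});
array x = array({1, 2, 3});
array y = array({11, 12, 13});
auto out = where(condition, x, y); // {1, 2, 13}
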
@@ -73,6 +73,7 @@ void time_unary_ops() {
 
 void time_binary_ops() {
   int M = 1000, N = 100, K = 10;
+  auto condition = random::randint(0, 2, {M, N, K});
   auto a = random::uniform({M, N, K});
   auto b = random::uniform({M, N, K});
   auto device = default_device();
@@ -84,7 +85,9 @@ void time_binary_ops() {
   TIME(divide, a, b, device);
   TIME(maximum, a, b, device);
   TIME(minimum, a, b, device);
+  TIME(where, condition, a, b, device);
 
+  condition = array({true});
   b = random::uniform({1});
   eval(b);
   TIMEM("scalar", add, a, b, device);
@@ -93,7 +96,9 @@ void time_binary_ops() {
   TIMEM("scalar", multiply, a, b, device);
   TIMEM("vector-scalar", divide, a, b, device);
   TIMEM("scalar-vector", divide, b, a, device);
+  TIMEM("scalar-vector", where, condition, a, b, device);
 
+  condition = broadcast_to(array({true}), {1000, 100});
   a = broadcast_to(random::uniform({1}), {1000, 100});
   b = broadcast_to(random::uniform({1}), {1000, 100});
   eval(a, b);
@@ -101,6 +106,7 @@ void time_binary_ops() {
   TIMEM("scalar-scalar broadcast", subtract, a, b, device);
   TIMEM("scalar-scalar broadcast", multiply, a, b, device);
   TIMEM("scalar-scalar broadcast", divide, a, b, device);
+  TIMEM("scalar-scalar broadcast", where, condition, a, b, device);
 }
 
 void time_strided_ops() {

@@ -64,6 +64,7 @@ DEFAULT(Reshape)
 DEFAULT(Remainder)
 DEFAULT(Round)
 DEFAULT(Scatter)
+DEFAULT(Select)
 DEFAULT(Sigmoid)
 DEFAULT(Sign)
 DEFAULT(Slice)

@@ -43,6 +43,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp

@@ -9,7 +9,7 @@ namespace mlx::core {
 
 namespace {
 
-enum BinaryOpType {
+enum class BinaryOpType {
   ScalarScalar,
   ScalarVector,
   VectorScalar,
@@ -20,17 +20,17 @@ enum BinaryOpType {
 BinaryOpType get_binary_op_type(const array& a, const array& b) {
   BinaryOpType bopt;
   if (a.data_size() == 1 && b.data_size() == 1) {
-    bopt = ScalarScalar;
+    bopt = BinaryOpType::ScalarScalar;
   } else if (a.data_size() == 1 && b.flags().contiguous) {
-    bopt = ScalarVector;
+    bopt = BinaryOpType::ScalarVector;
   } else if (b.data_size() == 1 && a.flags().contiguous) {
-    bopt = VectorScalar;
+    bopt = BinaryOpType::VectorScalar;
   } else if (
       a.flags().row_contiguous && b.flags().row_contiguous ||
       a.flags().col_contiguous && b.flags().col_contiguous) {
-    bopt = VectorVector;
+    bopt = BinaryOpType::VectorVector;
   } else {
-    bopt = General;
+    bopt = BinaryOpType::General;
   }
   return bopt;
 }
@@ -42,11 +42,11 @@ void set_binary_op_output_data(
     BinaryOpType bopt,
     bool donate_with_move = false) {
   switch (bopt) {
-    case ScalarScalar:
+    case BinaryOpType::ScalarScalar:
       out.set_data(
           allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
       break;
-    case ScalarVector:
+    case BinaryOpType::ScalarVector:
       if (b.is_donatable() && b.itemsize() == out.itemsize()) {
         if (donate_with_move) {
           out.move_shared_buffer(b);
@@ -61,7 +61,7 @@ void set_binary_op_output_data(
             b.flags());
       }
       break;
-    case VectorScalar:
+    case BinaryOpType::VectorScalar:
       if (a.is_donatable() && a.itemsize() == out.itemsize()) {
         if (donate_with_move) {
           out.move_shared_buffer(a);
@@ -76,7 +76,7 @@ void set_binary_op_output_data(
             a.flags());
       }
       break;
-    case VectorVector:
+    case BinaryOpType::VectorVector:
       if (a.is_donatable() && a.itemsize() == out.itemsize()) {
         if (donate_with_move) {
           out.move_shared_buffer(a);
@@ -97,7 +97,7 @@ void set_binary_op_output_data(
             a.flags());
       }
       break;
-    case General:
+    case BinaryOpType::General:
       if (a.is_donatable() && a.flags().row_contiguous &&
           a.itemsize() == out.itemsize() && a.size() == out.size()) {
         if (donate_with_move) {
@@ -424,25 +424,25 @@ void binary_op(
   set_binary_op_output_data(a, b, out, bopt);
 
   // The full computation is scalar scalar so call the base op once
-  if (bopt == ScalarScalar) {
+  if (bopt == BinaryOpType::ScalarScalar) {
     *(out.data<U>()) = op(*a.data<T>(), *b.data<T>());
     return;
   }
 
   // The full computation is scalar vector so delegate to the op
-  if (bopt == ScalarVector) {
+  if (bopt == BinaryOpType::ScalarVector) {
     opsv(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
     return;
   }
 
   // The full computation is vector scalar so delegate to the op
-  if (bopt == VectorScalar) {
+  if (bopt == BinaryOpType::VectorScalar) {
     opvs(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
     return;
   }
 
   // The full computation is vector vector so delegate to the op
-  if (bopt == VectorVector) {
+  if (bopt == BinaryOpType::VectorVector) {
     opvv(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
     return;
   }
@@ -475,17 +475,17 @@ void binary_op(
   // Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
   int dim = ndim;
   if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
-    bopt = VectorVector;
+    bopt = BinaryOpType::VectorVector;
     dim = d;
     // Case 2: LxM and Fx1 where L and F are broadcastable and M is row
     // contiguous
   } else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
-    bopt = VectorScalar;
+    bopt = BinaryOpType::VectorScalar;
     dim = d;
     // Case 3: Lx1 and FxM where L and F are broadcastable and M is row
     // contiguous
   } else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
-    bopt = ScalarVector;
+    bopt = BinaryOpType::ScalarVector;
     dim = d;
   }
 
@@ -495,20 +495,20 @@ void binary_op(
   size_t stride;
   if (dim == 0 || strides[dim - 1] < 16) {
     stride = 1;
-    bopt = General;
+    bopt = BinaryOpType::General;
     dim = ndim;
   } else {
     stride = strides[dim - 1];
   }
 
   switch (bopt) {
-    case VectorVector:
+    case BinaryOpType::VectorVector:
       binary_op_dispatch_dims<T, U>(a, b, out, opvv, dim, stride);
       break;
-    case VectorScalar:
+    case BinaryOpType::VectorScalar:
       binary_op_dispatch_dims<T, U>(a, b, out, opvs, dim, stride);
       break;
-    case ScalarVector:
+    case BinaryOpType::ScalarVector:
       binary_op_dispatch_dims<T, U>(a, b, out, opsv, dim, stride);
       break;
     default:

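The change from a plain enum to an enum class above is why every use site in this diff gains the BinaryOpType:: qualifier: scoped enumerators do not leak into the enclosing namespace. A brief illustration in standard C++ (not code from this repository):

enum Plain { General };         // 'General' becomes a name in the enclosing scope
enum class Scoped { General };  // must be spelled Scoped::General, so it cannot collide
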
@@ -260,14 +260,14 @@ void binary_op(
   set_binary_op_output_data(a, b, out_b, bopt);
 
   // The full computation is scalar scalar so call the base op once
-  if (bopt == ScalarScalar) {
+  if (bopt == BinaryOpType::ScalarScalar) {
     std::tie(*(out_a.data<U>()), *(out_b.data<U>())) =
         op(*a.data<T>(), *b.data<T>());
     return;
   }
 
   // The full computation is scalar vector so delegate to the op
-  if (bopt == ScalarVector) {
+  if (bopt == BinaryOpType::ScalarVector) {
     opsv(
         a.data<T>(),
         b.data<T>(),
@@ -278,7 +278,7 @@ void binary_op(
   }
 
   // The full computation is vector scalar so delegate to the op
-  if (bopt == VectorScalar) {
+  if (bopt == BinaryOpType::VectorScalar) {
     opvs(
         a.data<T>(),
         b.data<T>(),
@@ -289,7 +289,7 @@ void binary_op(
   }
 
   // The full computation is vector vector so delegate to the op
-  if (bopt == VectorVector) {
+  if (bopt == BinaryOpType::VectorVector) {
     opvv(
         a.data<T>(),
         b.data<T>(),
@@ -327,17 +327,17 @@ void binary_op(
   // Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
   int dim = ndim;
   if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
-    bopt = VectorVector;
+    bopt = BinaryOpType::VectorVector;
     dim = d;
     // Case 2: LxM and Fx1 where L and F are broadcastable and M is row
     // contiguous
   } else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
-    bopt = VectorScalar;
+    bopt = BinaryOpType::VectorScalar;
     dim = d;
     // Case 3: Lx1 and FxM where L and F are broadcastable and M is row
     // contiguous
   } else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
-    bopt = ScalarVector;
+    bopt = BinaryOpType::ScalarVector;
     dim = d;
   }
 
@@ -347,20 +347,20 @@ void binary_op(
   size_t stride;
   if (dim == 0 || strides[dim - 1] < 16) {
     stride = 1;
-    bopt = General;
+    bopt = BinaryOpType::General;
     dim = ndim;
   } else {
     stride = strides[dim - 1];
   }
 
   switch (bopt) {
-    case VectorVector:
+    case BinaryOpType::VectorVector:
       binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvv, dim, stride);
       break;
-    case VectorScalar:
+    case BinaryOpType::VectorScalar:
       binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvs, dim, stride);
       break;
-    case ScalarVector:
+    case BinaryOpType::ScalarVector:
       binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opsv, dim, stride);
       break;
     default:

@@ -87,6 +87,7 @@ DEFAULT(Reshape)
 DEFAULT(Round)
 DEFAULT(Scan)
 DEFAULT(Scatter)
+DEFAULT(Select)
 DEFAULT(Sigmoid)
 DEFAULT(Sign)
 DEFAULT(Sin)

@@ -588,4 +588,11 @@ struct LogicalOr {
   };
 };
 
+struct Select {
+  template <typename T>
+  T operator()(bool condition, T x, T y) {
+    return condition ? x : y;
+  }
+};
+
 } // namespace mlx::core::detail

mlx/backend/common/select.cpp (new file, 72 lines)
@@ -0,0 +1,72 @@
+// Copyright © 2023 Apple Inc.
+
+#include <cassert>
+
+#include "mlx/backend/common/ternary.h"
+#include "mlx/primitives.h"
+
+namespace mlx::core {
+
+namespace {
+
+template <typename Op>
+void select_op(
+    const array& a,
+    const array& b,
+    const array& c,
+    array& out,
+    Op op) {
+  switch (out.dtype()) {
+    case bool_:
+      ternary_op<bool, bool, bool, bool>(a, b, c, out, op);
+      break;
+    case uint8:
+      ternary_op<bool, uint8_t, uint8_t, uint8_t>(a, b, c, out, op);
+      break;
+    case uint16:
+      ternary_op<bool, uint16_t, uint16_t, uint16_t>(a, b, c, out, op);
+      break;
+    case uint32:
+      ternary_op<bool, uint32_t, uint32_t, uint32_t>(a, b, c, out, op);
+      break;
+    case uint64:
+      ternary_op<bool, uint64_t, uint64_t, uint64_t>(a, b, c, out, op);
+      break;
+    case int8:
+      ternary_op<bool, int8_t, int8_t, int8_t>(a, b, c, out, op);
+      break;
+    case int16:
+      ternary_op<bool, int16_t, int16_t, int16_t>(a, b, c, out, op);
+      break;
+    case int32:
+      ternary_op<bool, int32_t, int32_t, int32_t>(a, b, c, out, op);
+      break;
+    case int64:
+      ternary_op<bool, int64_t, int64_t, int64_t>(a, b, c, out, op);
+      break;
+    case float16:
+      ternary_op<bool, float16_t, float16_t, float16_t>(a, b, c, out, op);
+      break;
+    case float32:
+      ternary_op<bool, float, float, float>(a, b, c, out, op);
+      break;
+    case bfloat16:
+      ternary_op<bool, bfloat16_t, bfloat16_t, bfloat16_t>(a, b, c, out, op);
+      break;
+    case complex64:
+      ternary_op<bool, complex64_t, complex64_t, complex64_t>(a, b, c, out, op);
+      break;
+  }
+}
+
+} // namespace
+
+void Select::eval(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 3);
+  const auto& condition = inputs[0];
+  const auto& a = inputs[1];
+  const auto& b = inputs[2];
+  select_op(condition, a, b, out, detail::Select());
+}
+
+} // namespace mlx::core

mlx/backend/common/ternary.h (new file, 226 lines)
@@ -0,0 +1,226 @@
+// Copyright © 2023 Apple Inc.
+
+#pragma once
+#include "mlx/allocator.h"
+#include "mlx/array.h"
+#include "mlx/backend/common/ops.h"
+#include "mlx/backend/common/utils.h"
+namespace mlx::core {
+
+namespace {
+
+// TODO: Add support for more combinations of input types.
+enum class TernaryOpType {
+  ScalarScalarScalar,
+  General,
+};
+
+TernaryOpType
+get_ternary_op_type(const array& a, const array& b, const array& c) {
+  TernaryOpType topt;
+  if (a.data_size() == 1 && b.data_size() == 1 && c.data_size() == 1) {
+    topt = TernaryOpType::ScalarScalarScalar;
+  } else {
+    topt = TernaryOpType::General;
+  }
+  return topt;
+}
+
+void set_ternary_op_output_data(
+    const array& a,
+    const array& b,
+    const array& c,
+    array& out,
+    TernaryOpType topt,
+    bool donate_with_move = false) {
+  switch (topt) {
+    case TernaryOpType::ScalarScalarScalar:
+      out.set_data(
+          allocator::malloc_or_wait(out.itemsize()), 1, b.strides(), b.flags());
+      break;
+    case TernaryOpType::General:
+      out.set_data(allocator::malloc_or_wait(out.nbytes()));
+      break;
+  }
+}
+
+template <typename T1, typename T2, typename T3, typename U, typename Op>
+void ternary_op_dims1(
+    const array& a,
+    const array& b,
+    const array& c,
+    array& out,
+    Op op) {
+  const T1* a_ptr = a.data<T1>();
+  const T2* b_ptr = b.data<T2>();
+  const T3* c_ptr = c.data<T3>();
+
+  U* dst = out.data<U>();
+  size_t a_idx = 0;
+  size_t b_idx = 0;
+  size_t c_idx = 0;
+  for (size_t i = 0; i < out.size(); ++i) {
+    dst[i] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
+    a_idx += a.strides()[0];
+    b_idx += b.strides()[0];
+    c_idx += c.strides()[0];
+  }
+}
+
+template <typename T1, typename T2, typename T3, typename U, typename Op>
+void ternary_op_dims2(
+    const array& a,
+    const array& b,
+    const array& c,
+    array& out,
+    Op op) {
+  const T1* a_ptr = a.data<T1>();
+  const T2* b_ptr = b.data<T2>();
+  const T3* c_ptr = c.data<T3>();
+
+  U* dst = out.data<U>();
+  size_t a_idx = 0;
+  size_t b_idx = 0;
+  size_t c_idx = 0;
+  size_t out_idx = 0;
+  for (size_t i = 0; i < a.shape()[0]; ++i) {
+    for (size_t j = 0; j < a.shape()[1]; ++j) {
+      dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
+      a_idx += a.strides()[1];
+      b_idx += b.strides()[1];
+      c_idx += c.strides()[1];
+    }
+    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
+    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
+    c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
+  }
+}
+
+template <typename T1, typename T2, typename T3, typename U, typename Op>
+void ternary_op_dims3(
+    const array& a,
+    const array& b,
+    const array& c,
+    array& out,
+    Op op) {
+  const T1* a_ptr = a.data<T1>();
+  const T2* b_ptr = b.data<T2>();
+  const T3* c_ptr = c.data<T3>();
+  U* dst = out.data<U>();
+  size_t a_idx = 0;
+  size_t b_idx = 0;
+  size_t c_idx = 0;
+  size_t out_idx = 0;
+  for (size_t i = 0; i < a.shape()[0]; ++i) {
+    for (size_t j = 0; j < a.shape()[1]; ++j) {
+      for (size_t k = 0; k < a.shape()[2]; ++k) {
+        dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
+        a_idx += a.strides()[2];
+        b_idx += b.strides()[2];
+        c_idx += c.strides()[2];
+      }
+      a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
+      b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
+      c_idx += c.strides()[1] - c.strides()[2] * c.shape()[2];
+    }
+    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
+    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
+    c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
+  }
+}
+
+template <typename T1, typename T2, typename T3, typename U, typename Op>
+void ternary_op_dims4(
+    const array& a,
+    const array& b,
+    const array& c,
+    array& out,
+    Op op) {
+  const T1* a_ptr = a.data<T1>();
+  const T2* b_ptr = b.data<T2>();
+  const T3* c_ptr = c.data<T3>();
+
+  U* dst = out.data<U>();
+  size_t a_idx = 0;
+  size_t b_idx = 0;
+  size_t c_idx = 0;
+  size_t out_idx = 0;
+  for (size_t i = 0; i < a.shape()[0]; ++i) {
+    for (size_t j = 0; j < a.shape()[1]; ++j) {
+      for (size_t k = 0; k < a.shape()[2]; ++k) {
+        for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
+          dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
+          a_idx += a.strides()[3];
+          b_idx += b.strides()[3];
+          c_idx += c.strides()[3];
+        }
+        a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
+        b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
+        c_idx += c.strides()[2] - c.strides()[3] * c.shape()[3];
+      }
+      a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
+      b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
+      c_idx += c.strides()[1] - c.strides()[2] * c.shape()[2];
+    }
+    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
+    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
+    c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
+  }
+}
+
+template <typename T1, typename T2, typename T3, typename U, typename Op>
+void ternary_op_dispatch_dims(
+    const array& a,
+    const array& b,
+    const array& c,
+    array& out,
+    Op op) {
+  switch (out.ndim()) {
+    case 1:
+      ternary_op_dims1<T1, T2, T3, U, Op>(a, b, c, out, op);
+      return;
+    case 2:
+      ternary_op_dims2<T1, T2, T3, U, Op>(a, b, c, out, op);
+      return;
+    case 3:
+      ternary_op_dims3<T1, T2, T3, U, Op>(a, b, c, out, op);
+      return;
+    case 4:
+      ternary_op_dims4<T1, T2, T3, U, Op>(a, b, c, out, op);
+      return;
+  }
+
+  const T1* a_ptr = a.data<T1>();
+  const T2* b_ptr = b.data<T2>();
+  const T3* c_ptr = c.data<T3>();
+  U* dst = out.data<U>();
+  for (size_t i = 0; i < out.size(); i++) {
+    int a_idx = elem_to_loc(i, a.shape(), a.strides());
+    int b_idx = elem_to_loc(i, b.shape(), b.strides());
+    int c_idx = elem_to_loc(i, c.shape(), c.strides());
+    dst[i] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
+  }
+}
+
+template <typename T1, typename T2, typename T3, typename U, typename Op>
+void ternary_op(
+    const array& a,
+    const array& b,
+    const array& c,
+    array& out,
+    Op op) {
+  TernaryOpType topt = get_ternary_op_type(a, b, c);
+  set_ternary_op_output_data(a, b, c, out, topt);
+
+  // The full computation is scalar-scalar-scalar so we call the base op once.
+  if (topt == TernaryOpType::ScalarScalarScalar) {
+    *(out.data<U>()) = op(*a.data<T1>(), *b.data<T2>(), *c.data<T3>());
+    return;
+  }
+
+  ternary_op_dispatch_dims<T1, T2, T3, U>(a, b, c, out, op);
+}
+
+} // namespace
+
+} // namespace mlx::core

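The general fallback in ternary_op_dispatch_dims relies on elem_to_loc from mlx/backend/common/utils.h to turn a flat output index into a strided offset per input. A sketch of the idea such a helper implements (an illustration under that assumption, not the library's verbatim code):

// Map a flat row-major element index to an offset in a strided buffer.
size_t elem_to_loc_sketch(
    size_t idx,
    const std::vector<int>& shape,
    const std::vector<size_t>& strides) {
  size_t loc = 0;
  for (int d = shape.size() - 1; d >= 0; --d) {
    loc += (idx % shape[d]) * strides[d];
    idx /= shape[d];
  }
  return loc;
}

Broadcast inputs carry zero strides along expanded dimensions, so the same arithmetic covers shapes like a {1} condition selecting across an {M, N, K} output.
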
@@ -27,6 +27,7 @@ set(
   "scan"
   "softmax"
   "sort"
+  "ternary"
   "unary"
   "gather"
   "scatter"

@@ -1,6 +1,7 @@
 // Copyright © 2023-2024 Apple Inc.
 
 #include "mlx/backend/metal/kernels/binary.h"
+#include "mlx/backend/metal/kernels/ternary.h"
 #include "mlx/backend/metal/kernels/unary.h"
 
 typedef half float16_t;

mlx/backend/metal/kernels/ternary.h (new file, 10 lines)
@@ -0,0 +1,10 @@
+// Copyright © 2023-2024 Apple Inc.
+
+#pragma once
+
+struct Select {
+  template <typename T>
+  T operator()(bool condition, T x, T y) {
+    return condition ? x : y;
+  }
+};

mlx/backend/metal/kernels/ternary.metal (new file, 184 lines)
@@ -0,0 +1,184 @@
+// Copyright © 2023 Apple Inc.
+
+#include <metal_integer>
+#include <metal_math>
+
+#include "mlx/backend/metal/kernels/utils.h"
+#include "mlx/backend/metal/kernels/bf16.h"
+#include "mlx/backend/metal/kernels/ternary.h"
+
+template <typename T, typename Op>
+[[kernel]] void ternary_op_g_nd1(
+    device const bool* a,
+    device const T* b,
+    device const T* c,
+    device T* d,
+    constant const size_t& a_strides,
+    constant const size_t& b_strides,
+    constant const size_t& c_strides,
+    uint index [[thread_position_in_grid]]) {
+  auto a_idx = elem_to_loc_1(index, a_strides);
+  auto b_idx = elem_to_loc_1(index, b_strides);
+  auto c_idx = elem_to_loc_1(index, c_strides);
+  d[index] = Op()(a[a_idx], b[b_idx], c[c_idx]);
+}
+
+template <typename T, typename Op>
+[[kernel]] void ternary_op_g_nd2(
+    device const bool* a,
+    device const T* b,
+    device const T* c,
+    device T* d,
+    constant const size_t a_strides[2],
+    constant const size_t b_strides[2],
+    constant const size_t c_strides[2],
+    uint2 index [[thread_position_in_grid]],
+    uint2 grid_dim [[threads_per_grid]]) {
+  auto a_idx = elem_to_loc_2(index, a_strides);
+  auto b_idx = elem_to_loc_2(index, b_strides);
+  auto c_idx = elem_to_loc_2(index, c_strides);
+  size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
+  d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
+}
+
+template <typename T, typename Op>
+[[kernel]] void ternary_op_g_nd3(
+    device const bool* a,
+    device const T* b,
+    device const T* c,
+    device T* d,
+    constant const size_t a_strides[3],
+    constant const size_t b_strides[3],
+    constant const size_t c_strides[3],
+    uint3 index [[thread_position_in_grid]],
+    uint3 grid_dim [[threads_per_grid]]) {
+  auto a_idx = elem_to_loc_3(index, a_strides);
+  auto b_idx = elem_to_loc_3(index, b_strides);
+  auto c_idx = elem_to_loc_3(index, c_strides);
+  size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
+  d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
+}
+
+template <typename T, typename Op, int DIM>
+[[kernel]] void ternary_op_g_nd(
+    device const bool* a,
+    device const T* b,
+    device const T* c,
+    device T* d,
+    constant const int shape[DIM],
+    constant const size_t a_strides[DIM],
+    constant const size_t b_strides[DIM],
+    constant const size_t c_strides[DIM],
+    uint3 index [[thread_position_in_grid]],
+    uint3 grid_dim [[threads_per_grid]]) {
+  auto idx = elem_to_loc_3_nd<DIM>(index, shape, a_strides, b_strides, c_strides);
+  size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
+  d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
+}
+
+template <typename T, typename Op>
+[[kernel]] void ternary_op_g(
+    device const bool* a,
+    device const T* b,
+    device const T* c,
+    device T* d,
+    constant const int* shape,
+    constant const size_t* a_strides,
+    constant const size_t* b_strides,
+    constant const size_t* c_strides,
+    constant const int& ndim,
+    uint3 index [[thread_position_in_grid]],
+    uint3 grid_dim [[threads_per_grid]]) {
+  auto idx = elem_to_loc_3_nd(index, shape, a_strides, b_strides, c_strides, ndim);
+  size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
+  d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
+}
+
+#define instantiate_ternary_g(name, type, op) \
+  template [[host_name(name)]] \
+  [[kernel]] void ternary_op_g<type, op>( \
+      device const bool* a, \
+      device const type* b, \
+      device const type* c, \
+      device type* d, \
+      constant const int* shape, \
+      constant const size_t* a_strides, \
+      constant const size_t* b_strides, \
+      constant const size_t* c_strides, \
+      constant const int& ndim, \
+      uint3 index [[thread_position_in_grid]], \
+      uint3 grid_dim [[threads_per_grid]]); \
+
+#define instantiate_ternary_g_dim(name, type, op, dims) \
+  template [[host_name(name "_" #dims)]] \
+  [[kernel]] void ternary_op_g_nd<type, op, dims>( \
+      device const bool* a, \
+      device const type* b, \
+      device const type* c, \
+      device type* d, \
+      constant const int shape[dims], \
+      constant const size_t a_strides[dims], \
+      constant const size_t b_strides[dims], \
+      constant const size_t c_strides[dims], \
+      uint3 index [[thread_position_in_grid]], \
+      uint3 grid_dim [[threads_per_grid]]); \
+
+#define instantiate_ternary_g_nd(name, type, op) \
+  template [[host_name(name "_1")]] \
+  [[kernel]] void ternary_op_g_nd1<type, op>( \
+      device const bool* a, \
+      device const type* b, \
+      device const type* c, \
+      device type* d, \
+      constant const size_t& a_strides, \
+      constant const size_t& b_strides, \
+      constant const size_t& c_strides, \
+      uint index [[thread_position_in_grid]]); \
+  template [[host_name(name "_2")]] \
+  [[kernel]] void ternary_op_g_nd2<type, op>( \
+      device const bool* a, \
+      device const type* b, \
+      device const type* c, \
+      device type* d, \
+      constant const size_t a_strides[2], \
+      constant const size_t b_strides[2], \
+      constant const size_t c_strides[2], \
+      uint2 index [[thread_position_in_grid]], \
+      uint2 grid_dim [[threads_per_grid]]); \
+  template [[host_name(name "_3")]] \
+  [[kernel]] void ternary_op_g_nd3<type, op>( \
+      device const bool* a, \
+      device const type* b, \
+      device const type* c, \
+      device type* d, \
+      constant const size_t a_strides[3], \
+      constant const size_t b_strides[3], \
+      constant const size_t c_strides[3], \
+      uint3 index [[thread_position_in_grid]], \
+      uint3 grid_dim [[threads_per_grid]]); \
+  instantiate_ternary_g_dim(name, type, op, 4) \
+  instantiate_ternary_g_dim(name, type, op, 5) \
+
+#define instantiate_ternary_all(name, tname, type, op) \
+  instantiate_ternary_g("g" #name #tname, type, op) \
+  instantiate_ternary_g_nd("g" #name #tname, type, op) \
+
+#define instantiate_ternary_float(name, op) \
+  instantiate_ternary_all(name, float16, half, op) \
+  instantiate_ternary_all(name, float32, float, op) \
+  instantiate_ternary_all(name, bfloat16, bfloat16_t, op)
+
+#define instantiate_ternary_types(name, op) \
+  instantiate_ternary_all(name, bool_, bool, op) \
+  instantiate_ternary_all(name, uint8, uint8_t, op) \
+  instantiate_ternary_all(name, uint16, uint16_t, op) \
+  instantiate_ternary_all(name, uint32, uint32_t, op) \
+  instantiate_ternary_all(name, uint64, uint64_t, op) \
+  instantiate_ternary_all(name, int8, int8_t, op) \
+  instantiate_ternary_all(name, int16, int16_t, op) \
+  instantiate_ternary_all(name, int32, int32_t, op) \
+  instantiate_ternary_all(name, int64, int64_t, op) \
+  instantiate_ternary_all(name, complex64, complex64_t, op) \
+  instantiate_ternary_float(name, op)
+
+instantiate_ternary_types(select, Select)

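Note how the instantiations line up with the host side: instantiate_ternary_all concatenates "g" #name #tname, so instantiate_ternary_types(select, Select) emits host names such as gselectfloat32 plus the _1 through _5 dimension specializations. The Metal ternary_op launcher further down builds the matching string, roughly:

// Sketch of the host-side kernel-name assembly (see ternary_op below):
std::ostringstream kname;
kname << "g" << op << type_to_name(b); // e.g. "gselectfloat32"
// with "_" << shape.size() appended when a specialized kernel exists
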
@@ -91,6 +91,30 @@ inline size_t elem_to_loc(
   return loc;
 }
 
+template <int NDIM>
+inline uint3 elem_to_loc_3_nd(
+    uint3 elem,
+    constant const int shape[NDIM],
+    constant const size_t a_strides[NDIM],
+    constant const size_t b_strides[NDIM],
+    constant const size_t c_strides[NDIM]) {
+  uint3 loc = {
+      static_cast<uint>(
+          elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
+      static_cast<uint>(
+          elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2]),
+      static_cast<uint>(
+          elem.x * c_strides[NDIM - 1] + elem.y * c_strides[NDIM - 2])};
+  for (int d = NDIM - 3; d >= 0; --d) {
+    uint l = elem.z % shape[d];
+    loc.x += l * a_strides[d];
+    loc.y += l * b_strides[d];
+    loc.z += l * c_strides[d];
+    elem.z /= shape[d];
+  }
+  return loc;
+}
+
 template <int NDIM>
 inline uint2 elem_to_loc_2_nd(
     uint3 elem,
@@ -150,6 +174,30 @@ inline size_t elem_to_loc(
   return loc;
 }
 
+inline uint3 elem_to_loc_3_nd(
+    uint3 elem,
+    constant const int* shape,
+    constant const size_t* a_strides,
+    constant const size_t* b_strides,
+    constant const size_t* c_strides,
+    int ndim) {
+  uint3 loc = {
+      static_cast<uint>(
+          elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
+      static_cast<uint>(
+          elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2]),
+      static_cast<uint>(
+          elem.x * c_strides[ndim - 1] + elem.y * c_strides[ndim - 2])};
+  for (int d = ndim - 3; d >= 0; --d) {
+    uint l = elem.z % shape[d];
+    loc.x += l * a_strides[d];
+    loc.y += l * b_strides[d];
+    loc.z += l * c_strides[d];
+    elem.z /= shape[d];
+  }
+  return loc;
+}
+
 inline uint2 elem_to_loc_2_nd(
     uint3 elem,
     constant const int* shape,

@@ -6,6 +6,7 @@
 #include <sstream>
 
 #include "mlx/backend/common/binary.h"
+#include "mlx/backend/common/ternary.h"
 #include "mlx/backend/metal/copy.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/kernels/defines.h"
@@ -43,24 +44,25 @@ void binary_op(
 
   std::ostringstream kname;
   switch (bopt) {
-    case ScalarScalar:
+    case BinaryOpType::ScalarScalar:
       kname << "ss";
       break;
-    case ScalarVector:
+    case BinaryOpType::ScalarVector:
      kname << "sv";
       break;
-    case VectorScalar:
+    case BinaryOpType::VectorScalar:
       kname << "vs";
       break;
-    case VectorVector:
+    case BinaryOpType::VectorVector:
       kname << "vv";
       break;
-    case General:
+    case BinaryOpType::General:
       kname << "g";
       break;
   }
   kname << op << type_to_name(a);
-  if (bopt == General && shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
+  if (bopt == BinaryOpType::General &&
+      shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
     kname << "_" << shape.size();
   }
 
@@ -80,7 +82,7 @@ void binary_op(
   set_array_buffer(compute_encoder, outputs[0], 2);
   set_array_buffer(compute_encoder, outputs[1], 3);
 
-  if (bopt == General) {
+  if (bopt == BinaryOpType::General) {
     auto ndim = shape.size();
     if (ndim > 3) {
       compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
@@ -141,24 +143,25 @@ void binary_op(
 
   std::ostringstream kname;
   switch (bopt) {
-    case ScalarScalar:
+    case BinaryOpType::ScalarScalar:
       kname << "ss";
       break;
-    case ScalarVector:
+    case BinaryOpType::ScalarVector:
       kname << "sv";
       break;
-    case VectorScalar:
+    case BinaryOpType::VectorScalar:
       kname << "vs";
       break;
-    case VectorVector:
+    case BinaryOpType::VectorVector:
       kname << "vv";
       break;
-    case General:
+    case BinaryOpType::General:
       kname << "g";
       break;
   }
   kname << op << type_to_name(a);
-  if (bopt == General && shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
+  if (bopt == BinaryOpType::General &&
+      shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
     kname << "_" << shape.size();
   }
 
@@ -173,7 +176,7 @@ void binary_op(
   set_array_buffer(compute_encoder, donate_b ? out : b, 1);
   set_array_buffer(compute_encoder, out, 2);
 
-  if (bopt == General) {
+  if (bopt == BinaryOpType::General) {
     auto ndim = shape.size();
     if (ndim > 3) {
       compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 3);
@@ -202,7 +205,8 @@ void binary_op(
     compute_encoder->dispatchThreads(grid_dims, group_dims);
   } else {
     // Launch a 1D grid of threads
-    size_t nthreads = bopt == General ? out.size() : out.data_size();
+    size_t nthreads =
+        bopt == BinaryOpType::General ? out.size() : out.data_size();
     MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
     NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
     if (thread_group_size > nthreads) {
@@ -213,6 +217,86 @@ void binary_op(
   }
 }
 
+void ternary_op(
+    const std::vector<array>& inputs,
+    array& out,
+    const std::string op) {
+  assert(inputs.size() == 3);
+  auto& a = inputs[0];
+  auto& b = inputs[1];
+  auto& c = inputs[2];
+  TernaryOpType topt = get_ternary_op_type(a, b, c);
+  set_ternary_op_output_data(a, b, c, out, topt, true /* donate_with_move */);
+
+  if (out.size() == 0) {
+    return;
+  }
+
+  // Try to collapse contiguous dims
+  auto [shape, strides] = collapse_contiguous_dims(a, b, c, out);
+  auto& strides_a = strides[0];
+  auto& strides_b = strides[1];
+  auto& strides_c = strides[2];
+  auto& strides_out = strides[3];
+
+  std::ostringstream kname;
+  kname << "g";
+  kname << op << type_to_name(b);
+  if (topt == TernaryOpType::General &&
+      shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
+    kname << "_" << shape.size();
+  }
+
+  auto& s = out.primitive().stream();
+  auto& d = metal::device(s.device);
+  auto kernel = d.get_kernel(kname.str());
+  auto compute_encoder = d.get_command_encoder(s.index);
+  compute_encoder->setComputePipelineState(kernel);
+  set_array_buffer(compute_encoder, a, 0);
+  set_array_buffer(compute_encoder, b, 1);
+  set_array_buffer(compute_encoder, c, 2);
+  set_array_buffer(compute_encoder, out, 3);
+
+  auto ndim = shape.size();
+  if (ndim > 3) {
+    compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
+    compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5);
+    compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6);
+    compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 7);
+
+    if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
+      compute_encoder->setBytes(&ndim, sizeof(int), 8);
+    }
+  } else if (ndim > 0) {
+    // The shape is implicit in the grid for <= 3D
+    compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
+    compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
+    compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 6);
+  } else {
+    // For 0-dim we still need to bind something to these buffers since the
+    // current ternary kernels always access the strides.
+    size_t dummy_stride = 0;
+    int dummy_shape = 0;
+    compute_encoder->setBytes(&dummy_shape, sizeof(int), 4);
+    compute_encoder->setBytes(&dummy_stride, sizeof(size_t), 5);
+    compute_encoder->setBytes(&dummy_stride, sizeof(size_t), 6);
+    compute_encoder->setBytes(&dummy_stride, sizeof(size_t), 7);
+    compute_encoder->setBytes(&ndim, sizeof(int), 8);
+  }
+
+  // Launch up to 3D grid of threads
+  size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
+  size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
+  size_t rest = out.size() / (dim0 * dim1);
+  NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
+  if (thread_group_size != 1024) {
+    throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
+  }
+  MTL::Size group_dims = get_block_dims(dim0, dim1, rest);
+  MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
+  compute_encoder->dispatchThreads(grid_dims, group_dims);
+}
+
 void unary_op(
     const std::vector<array>& inputs,
     array& out,
@@ -619,6 +703,10 @@ void Multiply::eval_gpu(const std::vector<array>& inputs, array& out) {
   binary_op(inputs, out, "mul");
 }
 
+void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
+  ternary_op(inputs, out, "select");
+}
+
 void Negative::eval_gpu(const std::vector<array>& inputs, array& out) {
   unary_op(inputs, out, "neg");
 }

@@ -80,6 +80,7 @@ NO_GPU(Reshape)
 NO_GPU(Round)
 NO_GPU(Scan)
 NO_GPU(Scatter)
+NO_GPU(Select)
 NO_GPU(Sigmoid)
 NO_GPU(Sign)
 NO_GPU(Sin)

@@ -47,6 +47,10 @@ bool is_binary(const Primitive& p) {
       typeid(p) == typeid(Subtract));
 }
 
+bool is_ternary(const Primitive& p) {
+  return typeid(p) == typeid(Select);
+}
+
 bool is_broadcast(const Primitive& p) {
   return typeid(p) == typeid(Broadcast);
 }
@@ -60,14 +64,16 @@ bool is_reduction(const Primitive& p) {
 }
 
 bool is_fusable(const Primitive& p) {
-  return is_unary(p) || is_binary(p) || is_broadcast(p) || is_noop(p);
+  return is_unary(p) || is_binary(p) || is_ternary(p) || is_broadcast(p) ||
+      is_noop(p);
 }
 
 bool allows_shapeless(const Primitive& p) {
   return typeid(p) == typeid(Compiled) || is_unary(p) || is_binary(p) ||
       is_noop(p) || is_reduction(p) || typeid(p) == typeid(Softmax) ||
       typeid(p) == typeid(Sort) || typeid(p) == typeid(ArgSort) ||
-      typeid(p) == typeid(ArgPartition) || typeid(p) == typeid(Partition);
+      typeid(p) == typeid(ArgPartition) || typeid(p) == typeid(Partition) ||
+      typeid(p) == typeid(Select);
 }
 
 Compiled::Compiled(

mlx/ops.cpp (19 lines changed)
@@ -1149,13 +1149,20 @@ array isneginf(const array& a, StreamOrDevice s /* = {} */) {
 }
 
 array where(
-    const array& condition,
-    const array& x,
-    const array& y,
+    const array& a,
+    const array& b,
+    const array& c,
     StreamOrDevice s /* = {} */) {
-  // TODO, fix this to handle the NaN case when x has infs
-  auto mask = astype(condition, bool_, s);
-  return add(multiply(x, mask, s), multiply(y, logical_not(mask, s), s), s);
+  auto condition = astype(a, bool_, s);
+  Dtype out_dtype = promote_types(b.dtype(), c.dtype());
+  auto inputs = broadcast_arrays(
+      {condition, astype(b, out_dtype, s), astype(c, out_dtype, s)}, s);
+
+  return array(
+      inputs[0].shape(),
+      out_dtype,
+      std::make_unique<Select>(to_stream(s)),
+      inputs);
 }
 
 array allclose(

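This rewrite is the behavioral core of the commit. The old where computed add(multiply(x, mask), multiply(y, !mask)), so an inf or NaN in the unselected branch poisoned the result (inf * 0 is NaN), which is what the removed TODO was about. The Select primitive copies the chosen element directly. A hypothetical repro of the old failure mode (constructed for illustration; the committed tests exercise infs in TEST_CASE("test where") below):

const float inf = std::numeric_limits<float>::infinity();
auto out = where(array(true), array(2.0f), array(inf));
// old formulation: 2 * 1 + inf * 0 => NaN
// with Select:     2.0f, as expected
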
@@ -48,6 +48,54 @@ std::tuple<array, array, int> vmap_binary_op(
   return {a, b, to_ax};
 }
 
+std::tuple<array, array, array, int> vmap_ternary_op(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes,
+    const Stream& stream) {
+  assert(inputs.size() == 3);
+  assert(axes.size() == 3);
+
+  auto a = inputs[0];
+  auto b = inputs[1];
+  auto c = inputs[2];
+  int ndim = std::max(
+      {a.ndim() + (axes[0] == -1),
+       b.ndim() + (axes[1] == -1),
+       c.ndim() + (axes[2] == -1)});
+
+  auto expand_dims = [stream, ndim](auto in) {
+    auto shape = in.shape();
+    shape.insert(shape.begin(), ndim - shape.size(), 1);
+    return reshape(in, shape, stream);
+  };
+
+  int to_ax = (ndim - a.ndim()) + axes[0];
+  int from_ax1 = (ndim - b.ndim()) + axes[1];
+  int from_ax2 = (ndim - c.ndim()) + axes[2];
+  a = expand_dims(a);
+  b = expand_dims(b);
+  c = expand_dims(c);
+
+  auto find_tdims = [](auto x, int to_ax, int from_ax) {
+    std::vector<int> tdims(x.ndim());
+    std::iota(tdims.begin(), tdims.end(), 0);
+    tdims.erase(tdims.begin() + from_ax);
+    tdims.insert(tdims.begin() + to_ax, from_ax);
+    return tdims;
+  };
+
+  if (to_ax != from_ax1) {
+    std::vector<int> tdims = find_tdims(b, to_ax, from_ax1);
+    b = transpose(b, tdims, stream);
+  }
+
+  if (to_ax != from_ax2) {
+    std::vector<int> tdims = find_tdims(c, to_ax, from_ax2);
+    c = transpose(c, tdims, stream);
+  }
+  return {a, b, c, to_ax};
+}
+
 } // namespace
 
 std::vector<array> Primitive::jvp(
@@ -1775,6 +1823,76 @@ std::pair<std::vector<array>, std::vector<int>> Multiply::vmap(
   return {{multiply(a, b, stream())}, {to_ax}};
 }
 
+std::vector<array> Select::jvp(
+    const std::vector<array>& primals,
+    const std::vector<array>& tangents,
+    const std::vector<int>& argnums) {
+  assert(primals.size() == 3);
+  assert(tangents.size() == 3);
+
+  auto jvp_fun = [&](int i) {
+    int arg = argnums[i];
+
+    if (arg == 0) {
+      return zeros_like(primals[0], stream());
+    } else if (arg == 1) {
+      return multiply(
+          astype(primals[0], tangents[1].dtype(), stream()),
+          tangents[1],
+          stream());
+    } else {
+      return multiply(
+          astype(
+              logical_not(primals[0], stream()), tangents[2].dtype(), stream()),
+          tangents[2],
+          stream());
+    }
+  };
+
+  array jvp = jvp_fun(0);
+  for (int i = 1; i < argnums.size(); i++) {
+    jvp = add(jvp, jvp_fun(i));
+  }
+  return {jvp};
+}
+
+std::vector<array> Select::vjp(
+    const std::vector<array>& primals,
+    const std::vector<array>& cotangents,
+    const std::vector<int>& argnums,
+    const std::vector<array>&) {
+  assert(primals.size() == 3);
+  assert(cotangents.size() == 1);
+
+  std::vector<array> vjps;
+  for (auto arg : argnums) {
+    if (arg == 0) {
+      vjps.push_back(zeros_like(primals[0], stream()));
+    } else if (arg == 1) {
+      vjps.push_back(multiply(
+          astype(primals[0], cotangents[0].dtype(), stream()),
+          cotangents[0],
+          stream()));
+    } else if (arg == 2) {
+      vjps.push_back(multiply(
+          astype(
+              logical_not(primals[0], stream()),
+              cotangents[0].dtype(),
+              stream()),
+          cotangents[0],
+          stream()));
+    }
+  }
+  return vjps;
+}
+
+std::pair<std::vector<array>, std::vector<int>> Select::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  auto [a, b, c, to_ax] = vmap_ternary_op(inputs, axes, stream());
+  return {{where(a, b, c, stream())}, {to_ax}};
+}
+
 std::vector<array> Negative::vjp(
     const std::vector<array>& primals,
     const std::vector<array>& cotangents,

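The gradient rules above boil down to: no gradient flows into the boolean condition, the cotangent reaches x where the condition is true, and reaches y where it is false. A small sketch using the vjp helper the tests use (values chosen for illustration):

auto fn = [](std::vector<array> inputs) {
  return std::vector<array>{where(inputs[0], inputs[1], inputs[2])};
};
auto condition = array({true, false});
auto x = array({1.0f, 2.0f});
auto y = array({3.0f, 4.0f});
auto vjps = vjp(fn, {condition, x, y}, {ones({2})}).second;
// vjps[0]: zeros (the condition is not differentiable)
// vjps[1]: {1, 0} (cotangent masked by the condition)
// vjps[2]: {0, 1} (cotangent masked by its negation)
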
@@ -719,6 +719,23 @@ class DivMod : public Primitive {
   void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
 };
 
+class Select : public UnaryPrimitive {
+ public:
+  explicit Select(Stream stream) : UnaryPrimitive(stream){};
+
+  void eval_cpu(const std::vector<array>& inputs, array& out) override;
+  void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+  DEFINE_VMAP()
+  DEFINE_GRADS()
+  DEFINE_PRINT(Select)
+  DEFINE_DEFAULT_IS_EQUIVALENT()
+  DEFINE_INPUT_OUTPUT_SHAPE()
+
+ private:
+  void eval(const std::vector<array>& inputs, array& out);
+};
+
 class Remainder : public UnaryPrimitive {
  public:
   explicit Remainder(Stream stream) : UnaryPrimitive(stream){};

@@ -1075,6 +1075,37 @@ TEST_CASE("test jvp from vjp") {
     CHECK(compute_derivs(subtract));
     CHECK(compute_derivs(power));
   }
 
+  // Conditional selection element-wise op
+  {
+    auto condition = random::randint(0, 2, {5, 10});
+    auto x = random::uniform({5, 10});
+    auto y = random::uniform({5, 10});
+    eval(condition, x, y);
+
+    auto compute_derivs = [&condition, &x, &y](auto fn) {
+      auto fn_wrap = [&fn](std::vector<array> inputs) {
+        return std::vector<array>{
+            fn(inputs[0], inputs[1], inputs[2], default_device())};
+      };
+
+      // Compute vjp and add results
+      auto vjps = vjp(fn_wrap, {condition, x, y}, {ones(x.shape())}).second;
+      auto vjp_out = add(add(vjps[0], vjps[1]), vjps[2]);
+
+      // Compute jvp
+      array jvp_out =
+          jvp(fn_wrap,
+              {condition, x, y},
+              {ones(condition.shape()), ones(y.shape()), ones(x.shape())})
+              .second[0];
+
+      array result = array_equal(vjp_out, jvp_out);
+      return result.item<bool>();
+    };
+
+    CHECK(compute_derivs(where));
+  }
 }
 
 TEST_CASE("test complex gradients") {

@@ -2193,6 +2193,8 @@ TEST_CASE("test power") {
 }
 
 TEST_CASE("test where") {
+  const float inf = std::numeric_limits<float>::infinity();
+
   array condition(true);
   array x(1.0f);
   array y(0.0f);
@@ -2224,6 +2226,49 @@ TEST_CASE("test where") {
   out = where(condition, x, y);
   expected = array({1, 2, 2, 1}, {2, 2});
   CHECK(array_equal(where(condition, x, y), expected).item<bool>());
 
+  condition = array(true);
+  x = array({1, 2, 3});
+  y = array({3, 6, 13});
+  CHECK(array_equal(where(condition, x, y), array({1, 2, 3})).item<bool>());
+
+  condition = array(false);
+  x = array({1, 2, 3});
+  y = array({3, 6, 13});
+  CHECK(array_equal(where(condition, x, y), array({3, 6, 13})).item<bool>());
+
+  condition = array({1, 1, 0});
+  x = array({1, 2, 3});
+  y = array({11, 12, 13});
+  CHECK(array_equal(where(condition, x, y), array({1, 2, 13})).item<bool>());
+
+  condition = array({true, false}, {2, 1, 1});
+  x = array({1, 2, 3, 4}, {2, 1, 2});
+  y = array({11, 22, 33, 44}, {2, 2, 1});
+  expected = array({1, 2, 1, 2, 33, 33, 44, 44}, {2, 2, 2});
+  CHECK(array_equal(where(condition, x, y), expected).item<bool>());
+
+  condition = array({true, false, false});
+  x = array({inf, 2.0, 3.0});
+  y = array({10.0, 20.0, -inf});
+  CHECK(array_equal(where(condition, x, y), array({inf, 20.0, -inf}))
+            .item<bool>());
+
+  // 4-dim optimized case.
+  condition = array({false});
+  x = array({1, 2}, {2, 1, 1, 1});
+  y = array({3, 4}, {1, 1, 2, 1});
+  CHECK(array_equal(where(condition, x, y), array({3, 4, 3, 4}, {2, 1, 2, 1}))
+            .item<bool>());
+
+  // 5-dim optimized case.
+  condition = array({true, false}, {2, 1, 1, 1, 1});
+  x = array({1, 2, 3, 4}, {2, 1, 1, 1, 2});
+  y = array({11, 22}, {1, 1, 2, 1, 1});
+  CHECK(array_equal(
+            where(condition, x, y),
+            array({1, 2, 1, 2, 11, 11, 22, 22}, {2, 1, 2, 1, 2}))
+            .item<bool>());
 }
 
 TEST_CASE("test stack") {

@@ -138,6 +138,70 @@ TEST_CASE("test simple vmap") {
     CHECK(array_equal(out, x + y).item<bool>());
   }
 
+  // vmap where (ternary op)
+  {
+    auto fun = [](std::vector<array> inputs) {
+      auto out = where(inputs[0], inputs[1], inputs[2]);
+      return std::vector<array>{out};
+    };
+
+    auto vfun = vmap(fun);
+    array cond({true, false}, {2, 1});
+    array x({1.0, 2.0}, {2, 1});
+    array y({2.0, 4.0}, {2, 1});
+    auto out = vfun({cond, x, y})[0];
+    CHECK(array_equal(out, array({1.0, 4.0}, {2, 1})).item<bool>());
+
+    cond = array({true, true, false}, {1, 3});
+    x = ones({2, 1, 3});
+    y = zeros({3, 2});
+    vfun = vmap(fun, {1, 2, 0});
+    out = vfun({cond, x, y})[0];
+
+    CHECK(
+        array_equal(out, array({1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0}, {3, 2, 2}))
+            .item<bool>());
+
+    vfun = vmap(fun, {1, 2, 0}, {1});
+    out = vfun({cond, x, y})[0];
+    CHECK(
+        array_equal(out, array({1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0}, {2, 3, 2}))
+            .item<bool>());
+
+    cond = array({true, false});
+    x = array(2.);
+    y = ones({3, 2});
+    vfun = vmap(fun, {-1, -1, 0});
+    out = vfun({cond, x, y})[0];
+    CHECK(array_equal(out, array({2, 1, 2, 1, 2, 1}, {3, 2})).item<bool>());
+
+    cond = array({true, false});
+    x = ones({3, 2});
+    y = array(2.);
+    vfun = vmap(fun, {-1, 0, -1});
+    out = vfun({cond, x, y})[0];
+    CHECK(array_equal(out, array({1, 2, 1, 2, 1, 2}, {3, 2})).item<bool>());
+
+    CHECK_THROWS_AS(vmap(fun, {-1, -1, -1}, {0}), std::invalid_argument);
+    CHECK_THROWS_AS(vmap(fun, {-1, 0, -1}, {-1}), std::invalid_argument);
+    CHECK_THROWS_AS(vmap(fun, {-1, -1, 0}, {-1}), std::invalid_argument);
+    CHECK_THROWS_AS(vmap(fun, {0, -1, -1}, {-1}), std::invalid_argument);
+
+    cond = array({true, false});
+    x = array(1.);
+    y = array(2.);
+    vfun = vmap(fun, {-1, -1, -1}, {-1});
+    out = vfun({cond, x, y})[0];
+    CHECK(array_equal(out, array({1.0, 2.0})).item<bool>());
+
+    cond = array({1, 1, 1, 0, 0, 0}, {3, 2, 1});
+    x = ones({3, 2, 1});
+    y = full({3, 2, 1}, 2);
+    vfun = vmap(vmap(fun));
+    out = vfun({cond, x, y})[0];
+    CHECK(array_equal(out, array({1, 1, 1, 2, 2, 2}, {3, 2, 1})).item<bool>());
+  }
+
   // vmap with capturing closure
   {
     auto x = add(add(ones({2}), zeros({2})), zeros({2}));