From 126c9869c8005259b8511242b569d97d4b1d2b63 Mon Sep 17 00:00:00 2001
From: Rifur13
Date: Thu, 22 Feb 2024 18:10:48 -0500
Subject: [PATCH] Implement the 'where' primitive for conditional selection
 (#664)

---
 benchmarks/cpp/single_ops.cpp                 |   6 +
 mlx/backend/accelerate/primitives.cpp         |   1 +
 mlx/backend/common/CMakeLists.txt             |   1 +
 mlx/backend/common/binary.h                   |  44 ++--
 mlx/backend/common/binary_two.h               |  22 +-
 mlx/backend/common/default_primitives.cpp     |   1 +
 mlx/backend/common/ops.h                      |   7 +
 mlx/backend/common/select.cpp                 |  72 ++++++
 mlx/backend/common/ternary.h                  | 226 ++++++++++++++++++
 mlx/backend/metal/kernels/CMakeLists.txt      |   1 +
 mlx/backend/metal/kernels/compiled_preamble.h |   1 +
 mlx/backend/metal/kernels/ternary.h           |  10 +
 mlx/backend/metal/kernels/ternary.metal       | 184 ++++++++++++++
 mlx/backend/metal/kernels/utils.h             |  48 ++++
 mlx/backend/metal/primitives.cpp              | 118 +++++++--
 mlx/backend/no_metal/primitives.cpp           |   1 +
 mlx/compile.cpp                               |  10 +-
 mlx/ops.cpp                                   |  19 +-
 mlx/primitives.cpp                            | 118 +++++++++
 mlx/primitives.h                              |  17 ++
 tests/autograd_tests.cpp                      |  31 +++
 tests/ops_tests.cpp                           |  45 ++++
 tests/vmap_tests.cpp                          |  64 +++++
 23 files changed, 991 insertions(+), 56 deletions(-)
 create mode 100644 mlx/backend/common/select.cpp
 create mode 100644 mlx/backend/common/ternary.h
 create mode 100644 mlx/backend/metal/kernels/ternary.h
 create mode 100644 mlx/backend/metal/kernels/ternary.metal

diff --git a/benchmarks/cpp/single_ops.cpp b/benchmarks/cpp/single_ops.cpp
index 69cba09e9..4505282f1 100644
--- a/benchmarks/cpp/single_ops.cpp
+++ b/benchmarks/cpp/single_ops.cpp
@@ -73,6 +73,7 @@ void time_unary_ops() {
 
 void time_binary_ops() {
   int M = 1000, N = 100, K = 10;
+  auto condition = random::randint(0, 2, {M, N, K});
   auto a = random::uniform({M, N, K});
   auto b = random::uniform({M, N, K});
   auto device = default_device();
@@ -84,7 +85,9 @@ void time_binary_ops() {
   TIME(divide, a, b, device);
   TIME(maximum, a, b, device);
   TIME(minimum, a, b, device);
+  TIME(where, condition, a, b, device);
 
+  condition = array({true});
   b = random::uniform({1});
   eval(b);
   TIMEM("scalar", add, a, b, device);
@@ -93,7 +96,9 @@ void time_binary_ops() {
   TIMEM("scalar", multiply, a, b, device);
   TIMEM("vector-scalar", divide, a, b, device);
   TIMEM("scalar-vector", divide, b, a, device);
+  TIMEM("scalar-vector", where, condition, a, b, device);
 
+  condition = broadcast_to(array({true}), {1000, 100});
   a = broadcast_to(random::uniform({1}), {1000, 100});
   b = broadcast_to(random::uniform({1}), {1000, 100});
   eval(a, b);
@@ -101,6 +106,7 @@ void time_binary_ops() {
   TIMEM("scalar-scalar broadcast", subtract, a, b, device);
   TIMEM("scalar-scalar broadcast", multiply, a, b, device);
   TIMEM("scalar-scalar broadcast", divide, a, b, device);
+  TIMEM("scalar-scalar broadcast", where, condition, a, b, device);
 }
 
 void time_strided_ops() {
diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp
index e147b5888..1d4258f62 100644
--- a/mlx/backend/accelerate/primitives.cpp
+++ b/mlx/backend/accelerate/primitives.cpp
@@ -64,6 +64,7 @@ DEFAULT(Reshape)
 DEFAULT(Remainder)
 DEFAULT(Round)
 DEFAULT(Scatter)
+DEFAULT(Select)
 DEFAULT(Sigmoid)
 DEFAULT(Sign)
 DEFAULT(Slice)
diff --git a/mlx/backend/common/CMakeLists.txt b/mlx/backend/common/CMakeLists.txt
index 38a9819e5..569d690ef 100644
--- a/mlx/backend/common/CMakeLists.txt
+++ b/mlx/backend/common/CMakeLists.txt
@@ -43,6 +43,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
diff --git a/mlx/backend/common/binary.h b/mlx/backend/common/binary.h
index fb397b669..673d9cd14 100644
--- a/mlx/backend/common/binary.h
+++ b/mlx/backend/common/binary.h
@@ -9,7 +9,7 @@ namespace mlx::core {
 
 namespace {
 
-enum BinaryOpType {
+enum class BinaryOpType {
   ScalarScalar,
   ScalarVector,
   VectorScalar,
@@ -20,17 +20,17 @@ enum BinaryOpType {
 BinaryOpType get_binary_op_type(const array& a, const array& b) {
   BinaryOpType bopt;
   if (a.data_size() == 1 && b.data_size() == 1) {
-    bopt = ScalarScalar;
+    bopt = BinaryOpType::ScalarScalar;
   } else if (a.data_size() == 1 && b.flags().contiguous) {
-    bopt = ScalarVector;
+    bopt = BinaryOpType::ScalarVector;
   } else if (b.data_size() == 1 && a.flags().contiguous) {
-    bopt = VectorScalar;
+    bopt = BinaryOpType::VectorScalar;
   } else if (
       a.flags().row_contiguous && b.flags().row_contiguous ||
       a.flags().col_contiguous && b.flags().col_contiguous) {
-    bopt = VectorVector;
+    bopt = BinaryOpType::VectorVector;
   } else {
-    bopt = General;
+    bopt = BinaryOpType::General;
   }
   return bopt;
 }
@@ -42,11 +42,11 @@ void set_binary_op_output_data(
     BinaryOpType bopt,
     bool donate_with_move = false) {
   switch (bopt) {
-    case ScalarScalar:
+    case BinaryOpType::ScalarScalar:
       out.set_data(
           allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
       break;
-    case ScalarVector:
+    case BinaryOpType::ScalarVector:
       if (b.is_donatable() && b.itemsize() == out.itemsize()) {
         if (donate_with_move) {
           out.move_shared_buffer(b);
@@ -61,7 +61,7 @@ void set_binary_op_output_data(
           b.flags());
       }
       break;
-    case VectorScalar:
+    case BinaryOpType::VectorScalar:
       if (a.is_donatable() && a.itemsize() == out.itemsize()) {
         if (donate_with_move) {
           out.move_shared_buffer(a);
@@ -76,7 +76,7 @@ void set_binary_op_output_data(
           a.flags());
       }
       break;
-    case VectorVector:
+    case BinaryOpType::VectorVector:
       if (a.is_donatable() && a.itemsize() == out.itemsize()) {
         if (donate_with_move) {
           out.move_shared_buffer(a);
@@ -97,7 +97,7 @@ void set_binary_op_output_data(
           a.flags());
       }
       break;
-    case General:
+    case BinaryOpType::General:
       if (a.is_donatable() && a.flags().row_contiguous &&
           a.itemsize() == out.itemsize() && a.size() == out.size()) {
         if (donate_with_move) {
@@ -424,25 +424,25 @@ void binary_op(
   set_binary_op_output_data(a, b, out, bopt);
 
   // The full computation is scalar scalar so call the base op once
-  if (bopt == ScalarScalar) {
+  if (bopt == BinaryOpType::ScalarScalar) {
     *(out.data<U>()) = op(*a.data<T>(), *b.data<T>());
     return;
   }
 
   // The full computation is scalar vector so delegate to the op
-  if (bopt == ScalarVector) {
+  if (bopt == BinaryOpType::ScalarVector) {
     opsv(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
     return;
   }
 
   // The full computation is vector scalar so delegate to the op
-  if (bopt == VectorScalar) {
+  if (bopt == BinaryOpType::VectorScalar) {
     opvs(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
     return;
   }
 
   // The full computation is vector vector so delegate to the op
-  if (bopt == VectorVector) {
+  if (bopt == BinaryOpType::VectorVector) {
     opvv(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
     return;
   }
@@ -475,17 +475,17 @@ void binary_op(
   // Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
   int dim = ndim;
   if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
-    bopt = VectorVector;
+    bopt = BinaryOpType::VectorVector;
     dim = d;
     // Case 2: LxM and Fx1 where L and F are broadcastable and M is row
    // contiguous
   } else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
-    bopt = VectorScalar;
+    bopt = BinaryOpType::VectorScalar;
     dim = d;
     // Case 3: Lx1 and FxM where L and F are broadcastable and M is row
     // contiguous
   } else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
-    bopt = ScalarVector;
+    bopt = BinaryOpType::ScalarVector;
     dim = d;
   }
@@ -495,20 +495,20 @@ void binary_op(
   size_t stride;
   if (dim == 0 || strides[dim - 1] < 16) {
     stride = 1;
-    bopt = General;
+    bopt = BinaryOpType::General;
     dim = ndim;
   } else {
     stride = strides[dim - 1];
   }
 
   switch (bopt) {
-    case VectorVector:
+    case BinaryOpType::VectorVector:
       binary_op_dispatch_dims<T, U>(a, b, out, opvv, dim, stride);
       break;
-    case VectorScalar:
+    case BinaryOpType::VectorScalar:
       binary_op_dispatch_dims<T, U>(a, b, out, opvs, dim, stride);
       break;
-    case ScalarVector:
+    case BinaryOpType::ScalarVector:
       binary_op_dispatch_dims<T, U>(a, b, out, opsv, dim, stride);
       break;
     default:
diff --git a/mlx/backend/common/binary_two.h b/mlx/backend/common/binary_two.h
index 3468cb61e..3ce2f7110 100644
--- a/mlx/backend/common/binary_two.h
+++ b/mlx/backend/common/binary_two.h
@@ -260,14 +260,14 @@ void binary_op(
   set_binary_op_output_data(a, b, out_b, bopt);
 
   // The full computation is scalar scalar so call the base op once
-  if (bopt == ScalarScalar) {
+  if (bopt == BinaryOpType::ScalarScalar) {
     std::tie(*(out_a.data<U>()), *(out_b.data<U>())) = op(*a.data<T>(), *b.data<T>());
     return;
   }
 
   // The full computation is scalar vector so delegate to the op
-  if (bopt == ScalarVector) {
+  if (bopt == BinaryOpType::ScalarVector) {
     opsv(
         a.data<T>(),
         b.data<T>(),
@@ -278,7 +278,7 @@ void binary_op(
   }
 
   // The full computation is vector scalar so delegate to the op
-  if (bopt == VectorScalar) {
+  if (bopt == BinaryOpType::VectorScalar) {
     opvs(
         a.data<T>(),
         b.data<T>(),
@@ -289,7 +289,7 @@ void binary_op(
   }
 
   // The full computation is vector vector so delegate to the op
-  if (bopt == VectorVector) {
+  if (bopt == BinaryOpType::VectorVector) {
     opvv(
         a.data<T>(),
         b.data<T>(),
@@ -327,17 +327,17 @@ void binary_op(
   // Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
   int dim = ndim;
   if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
-    bopt = VectorVector;
+    bopt = BinaryOpType::VectorVector;
     dim = d;
     // Case 2: LxM and Fx1 where L and F are broadcastable and M is row
     // contiguous
   } else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
-    bopt = VectorScalar;
+    bopt = BinaryOpType::VectorScalar;
     dim = d;
     // Case 3: Lx1 and FxM where L and F are broadcastable and M is row
     // contiguous
   } else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
-    bopt = ScalarVector;
+    bopt = BinaryOpType::ScalarVector;
     dim = d;
   }
@@ -347,20 +347,20 @@ void binary_op(
   size_t stride;
   if (dim == 0 || strides[dim - 1] < 16) {
     stride = 1;
-    bopt = General;
+    bopt = BinaryOpType::General;
     dim = ndim;
   } else {
     stride = strides[dim - 1];
   }
 
   switch (bopt) {
-    case VectorVector:
+    case BinaryOpType::VectorVector:
       binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvv, dim, stride);
       break;
-    case VectorScalar:
+    case BinaryOpType::VectorScalar:
       binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvs, dim, stride);
       break;
-    case ScalarVector:
+    case BinaryOpType::ScalarVector:
       binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opsv, dim, stride);
       break;
     default:
diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp
index c65028d95..53b7a65f7 100644
--- a/mlx/backend/common/default_primitives.cpp
+++ b/mlx/backend/common/default_primitives.cpp
@@ -87,6 +87,7 @@ DEFAULT(Reshape)
 DEFAULT(Round)
 DEFAULT(Scan)
 DEFAULT(Scatter)
+DEFAULT(Select)
 DEFAULT(Sigmoid)
 DEFAULT(Sign)
 DEFAULT(Sin)
diff --git a/mlx/backend/common/ops.h b/mlx/backend/common/ops.h
index 8b2d7ab58..560296622 100644
--- a/mlx/backend/common/ops.h
+++ b/mlx/backend/common/ops.h
@@ -588,4 +588,11 @@ struct LogicalOr {
   };
 };
 
+struct Select {
+  template <typename T>
+  T operator()(bool condition, T x, T y) {
+    return condition ? x : y;
+  }
+};
+
 } // namespace mlx::core::detail
diff --git a/mlx/backend/common/select.cpp b/mlx/backend/common/select.cpp
new file mode 100644
index 000000000..1daa771b3
--- /dev/null
+++ b/mlx/backend/common/select.cpp
@@ -0,0 +1,72 @@
+// Copyright © 2023 Apple Inc.
+
+#include <cassert>
+
+#include "mlx/backend/common/ternary.h"
+#include "mlx/primitives.h"
+
+namespace mlx::core {
+
+namespace {
+
+template <typename Op>
+void select_op(
+    const array& a,
+    const array& b,
+    const array& c,
+    array& out,
+    Op op) {
+  switch (out.dtype()) {
+    case bool_:
+      ternary_op<bool, bool, bool, bool>(a, b, c, out, op);
+      break;
+    case uint8:
+      ternary_op<bool, uint8_t, uint8_t, uint8_t>(a, b, c, out, op);
+      break;
+    case uint16:
+      ternary_op<bool, uint16_t, uint16_t, uint16_t>(a, b, c, out, op);
+      break;
+    case uint32:
+      ternary_op<bool, uint32_t, uint32_t, uint32_t>(a, b, c, out, op);
+      break;
+    case uint64:
+      ternary_op<bool, uint64_t, uint64_t, uint64_t>(a, b, c, out, op);
+      break;
+    case int8:
+      ternary_op<bool, int8_t, int8_t, int8_t>(a, b, c, out, op);
+      break;
+    case int16:
+      ternary_op<bool, int16_t, int16_t, int16_t>(a, b, c, out, op);
+      break;
+    case int32:
+      ternary_op<bool, int32_t, int32_t, int32_t>(a, b, c, out, op);
+      break;
+    case int64:
+      ternary_op<bool, int64_t, int64_t, int64_t>(a, b, c, out, op);
+      break;
+    case float16:
+      ternary_op<bool, float16_t, float16_t, float16_t>(a, b, c, out, op);
+      break;
+    case float32:
+      ternary_op<bool, float, float, float>(a, b, c, out, op);
+      break;
+    case bfloat16:
+      ternary_op<bool, bfloat16_t, bfloat16_t, bfloat16_t>(a, b, c, out, op);
+      break;
+    case complex64:
+      ternary_op<bool, complex64_t, complex64_t, complex64_t>(a, b, c, out, op);
+      break;
+  }
+}
+
+} // namespace
+
+void Select::eval(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 3);
+  const auto& condition = inputs[0];
+  const auto& a = inputs[1];
+  const auto& b = inputs[2];
+  select_op(condition, a, b, out, detail::Select());
+}
+
+} // namespace mlx::core
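
[Reviewer note] The CPU path above dispatches detail::Select — elementwise `condition ? x : y` — over every dtype. A minimal usage sketch of the rewritten `where` (hypothetical values; assumes the public MLX C++ API from "mlx/mlx.h"), showing why a true selection primitive fixes the NaN issue called out in the old TODO in ops.cpp: the multiply/add lowering turned an unselected non-finite lane into NaN, while Select never touches it.

    // Sketch only -- not part of this patch.
    #include <cmath>
    #include "mlx/mlx.h"
    using namespace mlx::core;

    int main() {
      auto cond = array({true, false, true});
      auto x = array({1.0f, -INFINITY, 3.0f});
      auto y = zeros({3});
      // Old lowering: x * cond + y * !cond -> (-inf) * 0 == NaN in lane 1.
      // New Select primitive: lane 1 reads y only -> 0.0.
      auto out = where(cond, x, y); // [1.0, 0.0, 3.0]
      eval(out);
    }
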
diff --git a/mlx/backend/common/ternary.h b/mlx/backend/common/ternary.h
new file mode 100644
index 000000000..52d202df7
--- /dev/null
+++ b/mlx/backend/common/ternary.h
@@ -0,0 +1,226 @@
+// Copyright © 2023 Apple Inc.
+
+#pragma once
+
+#include "mlx/allocator.h"
+#include "mlx/array.h"
+#include "mlx/backend/common/ops.h"
+#include "mlx/backend/common/utils.h"
+
+namespace mlx::core {
+
+namespace {
+
+// TODO: Add support for more combinations of input types.
+enum class TernaryOpType {
+  ScalarScalarScalar,
+  General,
+};
+
+TernaryOpType
+get_ternary_op_type(const array& a, const array& b, const array& c) {
+  TernaryOpType topt;
+  if (a.data_size() == 1 && b.data_size() == 1 && c.data_size() == 1) {
+    topt = TernaryOpType::ScalarScalarScalar;
+  } else {
+    topt = TernaryOpType::General;
+  }
+  return topt;
+}
+
+void set_ternary_op_output_data(
+    const array& a,
+    const array& b,
+    const array& c,
+    array& out,
+    TernaryOpType topt,
+    bool donate_with_move = false) {
+  switch (topt) {
+    case TernaryOpType::ScalarScalarScalar:
+      out.set_data(
+          allocator::malloc_or_wait(out.itemsize()), 1, b.strides(), b.flags());
+      break;
+    case TernaryOpType::General:
+      out.set_data(allocator::malloc_or_wait(out.nbytes()));
+      break;
+  }
+}
+
+template <typename T1, typename T2, typename T3, typename U, typename Op>
+void ternary_op_dims1(
+    const array& a,
+    const array& b,
+    const array& c,
+    array& out,
+    Op op) {
+  const T1* a_ptr = a.data<T1>();
+  const T2* b_ptr = b.data<T2>();
+  const T3* c_ptr = c.data<T3>();
+
+  U* dst = out.data<U>();
+  size_t a_idx = 0;
+  size_t b_idx = 0;
+  size_t c_idx = 0;
+  for (size_t i = 0; i < out.size(); ++i) {
+    dst[i] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
+    a_idx += a.strides()[0];
+    b_idx += b.strides()[0];
+    c_idx += c.strides()[0];
+  }
+}
+
+template <typename T1, typename T2, typename T3, typename U, typename Op>
+void ternary_op_dims2(
+    const array& a,
+    const array& b,
+    const array& c,
+    array& out,
+    Op op) {
+  const T1* a_ptr = a.data<T1>();
+  const T2* b_ptr = b.data<T2>();
+  const T3* c_ptr = c.data<T3>();
+
+  U* dst = out.data<U>();
+  size_t a_idx = 0;
+  size_t b_idx = 0;
+  size_t c_idx = 0;
+  size_t out_idx = 0;
+  for (size_t i = 0; i < a.shape()[0]; ++i) {
+    for (size_t j = 0; j < a.shape()[1]; ++j) {
+      dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
+      a_idx += a.strides()[1];
+      b_idx += b.strides()[1];
+      c_idx += c.strides()[1];
+    }
+    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
+    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
+    c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
+  }
+}
+
+template <typename T1, typename T2, typename T3, typename U, typename Op>
+void ternary_op_dims3(
+    const array& a,
+    const array& b,
+    const array& c,
+    array& out,
+    Op op) {
+  const T1* a_ptr = a.data<T1>();
+  const T2* b_ptr = b.data<T2>();
+  const T3* c_ptr = c.data<T3>();
+  U* dst = out.data<U>();
+  size_t a_idx = 0;
+  size_t b_idx = 0;
+  size_t c_idx = 0;
+  size_t out_idx = 0;
+  for (size_t i = 0; i < a.shape()[0]; ++i) {
+    for (size_t j = 0; j < a.shape()[1]; ++j) {
+      for (size_t k = 0; k < a.shape()[2]; ++k) {
+        dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
+        a_idx += a.strides()[2];
+        b_idx += b.strides()[2];
+        c_idx += c.strides()[2];
+      }
+      a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
+      b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
+      c_idx += c.strides()[1] - c.strides()[2] * c.shape()[2];
+    }
+    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
+    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
+    c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
+  }
+}
+
+template <typename T1, typename T2, typename T3, typename U, typename Op>
+void ternary_op_dims4(
+    const array& a,
+    const array& b,
+    const array& c,
+    array& out,
+    Op op) {
+  const T1* a_ptr = a.data<T1>();
+  const T2* b_ptr = b.data<T2>();
+  const T3* c_ptr = c.data<T3>();
+
+  U* dst = out.data<U>();
+  size_t a_idx = 0;
+  size_t b_idx = 0;
+  size_t c_idx = 0;
+  size_t out_idx = 0;
+  for (size_t i = 0; i < a.shape()[0]; ++i) {
+    for (size_t j = 0; j < a.shape()[1]; ++j) {
+      for (size_t k = 0; k < a.shape()[2]; ++k) {
+        for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
+          dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
+          a_idx += a.strides()[3];
+          b_idx += b.strides()[3];
+          c_idx += c.strides()[3];
+        }
+        a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
+        b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
+        c_idx += c.strides()[2] - c.strides()[3] * c.shape()[3];
+      }
+      a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
+      b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
+      c_idx += c.strides()[1] - c.strides()[2] * c.shape()[2];
+    }
+    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
+    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
+    c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
+  }
+}
+
+template <typename T1, typename T2, typename T3, typename U, typename Op>
+void ternary_op_dispatch_dims(
+    const array& a,
+    const array& b,
+    const array& c,
+    array& out,
+    Op op) {
+  switch (out.ndim()) {
+    case 1:
+      ternary_op_dims1<T1, T2, T3, U>(a, b, c, out, op);
+      return;
+    case 2:
+      ternary_op_dims2<T1, T2, T3, U>(a, b, c, out, op);
+      return;
+    case 3:
+      ternary_op_dims3<T1, T2, T3, U>(a, b, c, out, op);
+      return;
+    case 4:
+      ternary_op_dims4<T1, T2, T3, U>(a, b, c, out, op);
+      return;
+  }
+
+  const T1* a_ptr = a.data<T1>();
+  const T2* b_ptr = b.data<T2>();
+  const T3* c_ptr = c.data<T3>();
+  U* dst = out.data<U>();
+  for (size_t i = 0; i < out.size(); i++) {
+    int a_idx = elem_to_loc(i, a.shape(), a.strides());
+    int b_idx = elem_to_loc(i, b.shape(), b.strides());
+    int c_idx = elem_to_loc(i, c.shape(), c.strides());
+    dst[i] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
+  }
+}
+
+template <typename T1, typename T2, typename T3, typename U, typename Op>
+void ternary_op(
+    const array& a,
+    const array& b,
+    const array& c,
+    array& out,
+    Op op) {
+  TernaryOpType topt = get_ternary_op_type(a, b, c);
+  set_ternary_op_output_data(a, b, c, out, topt);
+
+  // The full computation is scalar-scalar-scalar so we call the base op once.
+  if (topt == TernaryOpType::ScalarScalarScalar) {
+    *(out.data<U>()) = op(*a.data<T1>(), *b.data<T2>(), *c.data<T3>());
+    return;
+  }
+
+  ternary_op_dispatch_dims<T1, T2, T3, U>(a, b, c, out, op);
+}
+
+} // namespace
+
+} // namespace mlx::core
diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt
index afd2fbc8a..2b97ff76c 100644
--- a/mlx/backend/metal/kernels/CMakeLists.txt
+++ b/mlx/backend/metal/kernels/CMakeLists.txt
@@ -27,6 +27,7 @@ set(
   "scan"
   "softmax"
   "sort"
+  "ternary"
   "unary"
   "gather"
   "scatter"
diff --git a/mlx/backend/metal/kernels/compiled_preamble.h b/mlx/backend/metal/kernels/compiled_preamble.h
index d5bf33696..12fdc8117 100644
--- a/mlx/backend/metal/kernels/compiled_preamble.h
+++ b/mlx/backend/metal/kernels/compiled_preamble.h
@@ -1,6 +1,7 @@
 // Copyright © 2023-2024 Apple Inc.
 
 #include "mlx/backend/metal/kernels/binary.h"
+#include "mlx/backend/metal/kernels/ternary.h"
 #include "mlx/backend/metal/kernels/unary.h"
 
 typedef half float16_t;
diff --git a/mlx/backend/metal/kernels/ternary.h b/mlx/backend/metal/kernels/ternary.h
new file mode 100644
index 000000000..e0235d9dd
--- /dev/null
+++ b/mlx/backend/metal/kernels/ternary.h
@@ -0,0 +1,10 @@
+// Copyright © 2023-2024 Apple Inc.
+
+#pragma once
+
+struct Select {
+  template <typename T>
+  T operator()(bool condition, T x, T y) {
+    return condition ? x : y;
+  }
+};
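
[Reviewer note] On the ternary_op_dims{2,3,4} loops in common/ternary.h above: rather than paying a full elem_to_loc per element, they carry a running offset per input and, when an inner axis finishes, rewind it and step the next axis out. Broadcast inputs simply carry stride 0 on the broadcast axes. A self-contained sketch of the same walk (plain C++; made-up 2x3 shape, signed offsets for clarity — the real code uses size_t and relies on unsigned wraparound):

    #include <cstdio>

    int main() {
      const int shape[2] = {2, 3};
      const long strides[2] = {0, 1}; // stride 0 on axis 0: a row broadcast
      const float data[3] = {10.f, 20.f, 30.f};
      long idx = 0;
      for (int i = 0; i < shape[0]; ++i) {
        for (int j = 0; j < shape[1]; ++j) {
          std::printf("%g ", data[idx]);
          idx += strides[1];                       // step innermost axis
        }
        idx += strides[0] - strides[1] * shape[1]; // rewind inner, step outer
      }
      // Prints: 10 20 30 10 20 30
    }
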
diff --git a/mlx/backend/metal/kernels/ternary.metal b/mlx/backend/metal/kernels/ternary.metal
new file mode 100644
index 000000000..f3021fc11
--- /dev/null
+++ b/mlx/backend/metal/kernels/ternary.metal
@@ -0,0 +1,184 @@
+// Copyright © 2023 Apple Inc.
+
+#include <metal_integer>
+#include <metal_math>
+
+#include "mlx/backend/metal/kernels/utils.h"
+#include "mlx/backend/metal/kernels/bf16.h"
+#include "mlx/backend/metal/kernels/ternary.h"
+
+template <typename T, typename Op>
+[[kernel]] void ternary_op_g_nd1(
+    device const bool* a,
+    device const T* b,
+    device const T* c,
+    device T* d,
+    constant const size_t& a_strides,
+    constant const size_t& b_strides,
+    constant const size_t& c_strides,
+    uint index [[thread_position_in_grid]]) {
+  auto a_idx = elem_to_loc_1(index, a_strides);
+  auto b_idx = elem_to_loc_1(index, b_strides);
+  auto c_idx = elem_to_loc_1(index, c_strides);
+  d[index] = Op()(a[a_idx], b[b_idx], c[c_idx]);
+}
+
+template <typename T, typename Op>
+[[kernel]] void ternary_op_g_nd2(
+    device const bool* a,
+    device const T* b,
+    device const T* c,
+    device T* d,
+    constant const size_t a_strides[2],
+    constant const size_t b_strides[2],
+    constant const size_t c_strides[2],
+    uint2 index [[thread_position_in_grid]],
+    uint2 grid_dim [[threads_per_grid]]) {
+  auto a_idx = elem_to_loc_2(index, a_strides);
+  auto b_idx = elem_to_loc_2(index, b_strides);
+  auto c_idx = elem_to_loc_2(index, c_strides);
+  size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
+  d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
+}
+
+template <typename T, typename Op>
+[[kernel]] void ternary_op_g_nd3(
+    device const bool* a,
+    device const T* b,
+    device const T* c,
+    device T* d,
+    constant const size_t a_strides[3],
+    constant const size_t b_strides[3],
+    constant const size_t c_strides[3],
+    uint3 index [[thread_position_in_grid]],
+    uint3 grid_dim [[threads_per_grid]]) {
+  auto a_idx = elem_to_loc_3(index, a_strides);
+  auto b_idx = elem_to_loc_3(index, b_strides);
+  auto c_idx = elem_to_loc_3(index, c_strides);
+  size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
+  d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
+}
+
+template <typename T, typename Op, int DIM>
+[[kernel]] void ternary_op_g_nd(
+    device const bool* a,
+    device const T* b,
+    device const T* c,
+    device T* d,
+    constant const int shape[DIM],
+    constant const size_t a_strides[DIM],
+    constant const size_t b_strides[DIM],
+    constant const size_t c_strides[DIM],
+    uint3 index [[thread_position_in_grid]],
+    uint3 grid_dim [[threads_per_grid]]) {
+  auto idx = elem_to_loc_3_nd<DIM>(index, shape, a_strides, b_strides, c_strides);
+  size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
+  d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
+}
+
+template <typename T, typename Op>
+[[kernel]] void ternary_op_g(
+    device const bool* a,
+    device const T* b,
+    device const T* c,
+    device T* d,
+    constant const int* shape,
+    constant const size_t* a_strides,
+    constant const size_t* b_strides,
+    constant const size_t* c_strides,
+    constant const int& ndim,
+    uint3 index [[thread_position_in_grid]],
+    uint3 grid_dim [[threads_per_grid]]) {
+  auto idx = elem_to_loc_3_nd(index, shape, a_strides, b_strides, c_strides, ndim);
+  size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
+  d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
+}
+
+#define instantiate_ternary_g(name, type, op) \
+  template [[host_name(name)]] \
+  [[kernel]] void ternary_op_g<type, op>( \
+      device const bool* a, \
+      device const type* b, \
+      device const type* c, \
+      device type* d, \
+      constant const int* shape, \
+      constant const size_t* a_strides, \
+      constant const size_t* b_strides, \
+      constant const size_t* c_strides, \
+      constant const int& ndim, \
+      uint3 index [[thread_position_in_grid]], \
+      uint3 grid_dim [[threads_per_grid]]); \
+
+#define instantiate_ternary_g_dim(name, type, op, dims) \
+  template [[host_name(name "_" #dims)]] \
+  [[kernel]] void ternary_op_g_nd<type, op, dims>( \
+      device const bool* a, \
+      device const type* b, \
+      device const type* c, \
+      device type* d, \
+      constant const int shape[dims], \
+      constant const size_t a_strides[dims], \
+      constant const size_t b_strides[dims], \
+      constant const size_t c_strides[dims], \
+      uint3 index [[thread_position_in_grid]], \
+      uint3 grid_dim [[threads_per_grid]]); \
+
+#define instantiate_ternary_g_nd(name, type, op) \
+  template [[host_name(name "_1")]] \
+  [[kernel]] void ternary_op_g_nd1<type, op>( \
+      device const bool* a, \
+      device const type* b, \
+      device const type* c, \
+      device type* d, \
+      constant const size_t& a_strides, \
+      constant const size_t& b_strides, \
+      constant const size_t& c_strides, \
+      uint index [[thread_position_in_grid]]); \
+  template [[host_name(name "_2")]] \
+  [[kernel]] void ternary_op_g_nd2<type, op>( \
+      device const bool* a, \
+      device const type* b, \
+      device const type* c, \
+      device type* d, \
+      constant const size_t a_strides[2], \
+      constant const size_t b_strides[2], \
+      constant const size_t c_strides[2], \
+      uint2 index [[thread_position_in_grid]], \
+      uint2 grid_dim [[threads_per_grid]]); \
+  template [[host_name(name "_3")]] \
+  [[kernel]] void ternary_op_g_nd3<type, op>( \
+      device const bool* a, \
+      device const type* b, \
+      device const type* c, \
+      device type* d, \
+      constant const size_t a_strides[3], \
+      constant const size_t b_strides[3], \
+      constant const size_t c_strides[3], \
+      uint3 index [[thread_position_in_grid]], \
+      uint3 grid_dim [[threads_per_grid]]); \
+  instantiate_ternary_g_dim(name, type, op, 4) \
+  instantiate_ternary_g_dim(name, type, op, 5) \
+
+#define instantiate_ternary_all(name, tname, type, op) \
+  instantiate_ternary_g("g" #name #tname, type, op) \
+  instantiate_ternary_g_nd("g" #name #tname, type, op) \
+
+#define instantiate_ternary_float(name, op) \
+  instantiate_ternary_all(name, float16, half, op) \
+  instantiate_ternary_all(name, float32, float, op) \
+  instantiate_ternary_all(name, bfloat16, bfloat16_t, op)
+
+#define instantiate_ternary_types(name, op) \
+  instantiate_ternary_all(name, bool_, bool, op) \
+  instantiate_ternary_all(name, uint8, uint8_t, op) \
+  instantiate_ternary_all(name, uint16, uint16_t, op) \
+  instantiate_ternary_all(name, uint32, uint32_t, op) \
+  instantiate_ternary_all(name, uint64, uint64_t, op) \
+  instantiate_ternary_all(name, int8, int8_t, op) \
+  instantiate_ternary_all(name, int16, int16_t, op) \
+  instantiate_ternary_all(name, int32, int32_t, op) \
+  instantiate_ternary_all(name, int64, int64_t, op) \
+  instantiate_ternary_all(name, complex64, complex64_t, op) \
+  instantiate_ternary_float(name, op)
+
+instantiate_ternary_types(select, Select)
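
[Reviewer note] The macros above register each kernel under a host name built as "g" + op + dtype, with a "_<rank>" suffix for the fixed-rank variants (ranks 1 through 5), e.g. gselectfloat32_2. The host-side ternary_op in metal/primitives.cpp (further below) must assemble exactly the same string. A simplified mirror of that lookup — assuming type_to_name returns suffixes like "float32" and MAX_BINARY_SPECIALIZED_DIMS == 5 (it lives in kernels/defines.h in the real source):

    #include <sstream>
    #include <string>

    std::string ternary_kernel_name(
        const std::string& op,    // "select"
        const std::string& tname, // type_to_name(b), e.g. "float32"
        size_t ndim,
        bool general) {
      std::ostringstream kname;
      kname << "g" << op << tname;  // general, strided kernel
      if (general && ndim <= 5) {
        kname << "_" << ndim;       // fixed-rank specialization
      }
      return kname.str();           // e.g. "gselectfloat32_2"
    }
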
diff --git a/mlx/backend/metal/kernels/utils.h b/mlx/backend/metal/kernels/utils.h
index 8ef1127b6..9c3d20b30 100644
--- a/mlx/backend/metal/kernels/utils.h
+++ b/mlx/backend/metal/kernels/utils.h
@@ -91,6 +91,30 @@ inline size_t elem_to_loc(
   return loc;
 }
 
+template <int NDIM>
+inline uint3 elem_to_loc_3_nd(
+    uint3 elem,
+    constant const int shape[NDIM],
+    constant const size_t a_strides[NDIM],
+    constant const size_t b_strides[NDIM],
+    constant const size_t c_strides[NDIM]) {
+  uint3 loc = {
+      static_cast<uint>(
+          elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
+      static_cast<uint>(
+          elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2]),
+      static_cast<uint>(
+          elem.x * c_strides[NDIM - 1] + elem.y * c_strides[NDIM - 2])};
+  for (int d = NDIM - 3; d >= 0; --d) {
+    uint l = elem.z % shape[d];
+    loc.x += l * a_strides[d];
+    loc.y += l * b_strides[d];
+    loc.z += l * c_strides[d];
+    elem.z /= shape[d];
+  }
+  return loc;
+}
+
 template <int NDIM>
 inline uint2 elem_to_loc_2_nd(
     uint3 elem,
@@ -150,6 +174,30 @@ inline size_t elem_to_loc(
   return loc;
 }
 
+inline uint3 elem_to_loc_3_nd(
+    uint3 elem,
+    constant const int* shape,
+    constant const size_t* a_strides,
+    constant const size_t* b_strides,
+    constant const size_t* c_strides,
+    int ndim) {
+  uint3 loc = {
+      static_cast<uint>(
+          elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
+      static_cast<uint>(
+          elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2]),
+      static_cast<uint>(
+          elem.x * c_strides[ndim - 1] + elem.y * c_strides[ndim - 2])};
+  for (int d = ndim - 3; d >= 0; --d) {
+    uint l = elem.z % shape[d];
+    loc.x += l * a_strides[d];
+    loc.y += l * b_strides[d];
+    loc.z += l * c_strides[d];
+    elem.z /= shape[d];
+  }
+  return loc;
+}
+
 inline uint2 elem_to_loc_2_nd(
     uint3 elem,
     constant const int* shape,
diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp
index 056bbbc80..301adcdea 100644
--- a/mlx/backend/metal/primitives.cpp
+++ b/mlx/backend/metal/primitives.cpp
@@ -6,6 +6,7 @@
 #include <sstream>
 
 #include "mlx/backend/common/binary.h"
+#include "mlx/backend/common/ternary.h"
 #include "mlx/backend/metal/copy.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/kernels/defines.h"
@@ -43,24 +44,25 @@ void binary_op(
 
   std::ostringstream kname;
   switch (bopt) {
-    case ScalarScalar:
+    case BinaryOpType::ScalarScalar:
       kname << "ss";
       break;
-    case ScalarVector:
+    case BinaryOpType::ScalarVector:
       kname << "sv";
       break;
-    case VectorScalar:
+    case BinaryOpType::VectorScalar:
       kname << "vs";
       break;
-    case VectorVector:
+    case BinaryOpType::VectorVector:
       kname << "vv";
       break;
-    case General:
+    case BinaryOpType::General:
       kname << "g";
       break;
   }
   kname << op << type_to_name(a);
-  if (bopt == General && shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
+  if (bopt == BinaryOpType::General &&
+      shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
     kname << "_" << shape.size();
   }
 
@@ -80,7 +82,7 @@ void binary_op(
   set_array_buffer(compute_encoder, outputs[0], 2);
   set_array_buffer(compute_encoder, outputs[1], 3);
 
-  if (bopt == General) {
+  if (bopt == BinaryOpType::General) {
     auto ndim = shape.size();
     if (ndim > 3) {
       compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
@@ -141,24 +143,25 @@ void binary_op(
 
   std::ostringstream kname;
   switch (bopt) {
-    case ScalarScalar:
+    case BinaryOpType::ScalarScalar:
       kname << "ss";
       break;
-    case ScalarVector:
+    case BinaryOpType::ScalarVector:
       kname << "sv";
       break;
-    case VectorScalar:
+    case BinaryOpType::VectorScalar:
       kname << "vs";
       break;
-    case VectorVector:
+    case BinaryOpType::VectorVector:
       kname << "vv";
       break;
-    case General:
+    case BinaryOpType::General:
       kname << "g";
       break;
   }
   kname << op << type_to_name(a);
-  if (bopt == General && shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
+  if (bopt == BinaryOpType::General &&
+      shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
     kname << "_" << shape.size();
   }
 
@@ -173,7 +176,7 @@ void binary_op(
   set_array_buffer(compute_encoder, donate_b ? out : b, 1);
   set_array_buffer(compute_encoder, out, 2);
 
-  if (bopt == General) {
+  if (bopt == BinaryOpType::General) {
     auto ndim = shape.size();
     if (ndim > 3) {
       compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 3);
@@ -202,7 +205,8 @@ void binary_op(
     compute_encoder->dispatchThreads(grid_dims, group_dims);
   } else {
     // Launch a 1D grid of threads
-    size_t nthreads = bopt == General ? out.size() : out.data_size();
+    size_t nthreads =
+        bopt == BinaryOpType::General ? out.size() : out.data_size();
     MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
     NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
     if (thread_group_size > nthreads) {
@@ -213,6 +217,86 @@
   }
 }
 
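[Reviewer note] The new ternary_op below always takes the general (strided) route and launches an up-to-3D grid: the last collapsed axis maps to grid x, the second-to-last to y, and everything remaining is flattened into z. A sketch of that dim0/dim1/rest decomposition (plain C++, hypothetical shape; mirrors the computation in the code below):

    #include <cassert>
    #include <cstddef>
    #include <vector>

    void grid_dims(const std::vector<int>& shape, size_t& x, size_t& y, size_t& z) {
      size_t ndim = shape.size(), size = 1;
      for (int d : shape) size *= d;
      x = ndim > 0 ? shape[ndim - 1] : 1; // innermost axis
      y = ndim > 1 ? shape[ndim - 2] : 1; // next axis
      z = size / (x * y);                 // all remaining axes, flattened
    }

    int main() {
      size_t x, y, z;
      grid_dims({7, 5, 4, 3}, x, y, z);
      assert(x == 3 && y == 4 && z == 35); // 7 * 5 outer slices per (x, y) tile
    }
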
+void ternary_op(
+    const std::vector<array>& inputs,
+    array& out,
+    const std::string op) {
+  assert(inputs.size() == 3);
+  auto& a = inputs[0];
+  auto& b = inputs[1];
+  auto& c = inputs[2];
+  TernaryOpType topt = get_ternary_op_type(a, b, c);
+  set_ternary_op_output_data(a, b, c, out, topt, true /* donate_with_move */);
+
+  if (out.size() == 0) {
+    return;
+  }
+
+  // Try to collapse contiguous dims
+  auto [shape, strides] = collapse_contiguous_dims(a, b, c, out);
+  auto& strides_a = strides[0];
+  auto& strides_b = strides[1];
+  auto& strides_c = strides[2];
+  auto& strides_out = strides[3];
+
+  std::ostringstream kname;
+  kname << "g";
+  kname << op << type_to_name(b);
+  if (topt == TernaryOpType::General &&
+      shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
+    kname << "_" << shape.size();
+  }
+
+  auto& s = out.primitive().stream();
+  auto& d = metal::device(s.device);
+  auto kernel = d.get_kernel(kname.str());
+  auto compute_encoder = d.get_command_encoder(s.index);
+  compute_encoder->setComputePipelineState(kernel);
+  set_array_buffer(compute_encoder, a, 0);
+  set_array_buffer(compute_encoder, b, 1);
+  set_array_buffer(compute_encoder, c, 2);
+  set_array_buffer(compute_encoder, out, 3);
+
+  auto ndim = shape.size();
+  if (ndim > 3) {
+    compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
+    compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5);
+    compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6);
+    compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 7);
+
+    if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
+      compute_encoder->setBytes(&ndim, sizeof(int), 8);
+    }
+  } else if (ndim > 0) {
+    // The shape is implicit in the grid for <= 3D
+    compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
+    compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
+    compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 6);
+  } else {
+    // For 0-dim we still need to bind something to these buffers since the
+    // current ternary kernels always access the strides.
+    size_t dummy_stride = 0;
+    int dummy_shape = 0;
+    compute_encoder->setBytes(&dummy_shape, sizeof(int), 4);
+    compute_encoder->setBytes(&dummy_stride, sizeof(size_t), 5);
+    compute_encoder->setBytes(&dummy_stride, sizeof(size_t), 6);
+    compute_encoder->setBytes(&dummy_stride, sizeof(size_t), 7);
+    compute_encoder->setBytes(&ndim, sizeof(int), 8);
+  }
+
+  // Launch up to 3D grid of threads
+  size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
+  size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
+  size_t rest = out.size() / (dim0 * dim1);
+  NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
+  if (thread_group_size != 1024) {
+    throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
+  }
+  MTL::Size group_dims = get_block_dims(dim0, dim1, rest);
+  MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
+  compute_encoder->dispatchThreads(grid_dims, group_dims);
+}
+
 void unary_op(
     const std::vector<array>& inputs,
     array& out,
@@ -619,6 +703,10 @@ void Multiply::eval_gpu(const std::vector<array>& inputs, array& out) {
   binary_op(inputs, out, "mul");
 }
 
+void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
+  ternary_op(inputs, out, "select");
+}
+
 void Negative::eval_gpu(const std::vector<array>& inputs, array& out) {
   unary_op(inputs, out, "neg");
 }
diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp
index 8e66f56b3..4234eeb1c 100644
--- a/mlx/backend/no_metal/primitives.cpp
+++ b/mlx/backend/no_metal/primitives.cpp
@@ -80,6 +80,7 @@ NO_GPU(Reshape)
 NO_GPU(Round)
 NO_GPU(Scan)
 NO_GPU(Scatter)
+NO_GPU(Select)
 NO_GPU(Sigmoid)
 NO_GPU(Sign)
 NO_GPU(Sin)
diff --git a/mlx/compile.cpp b/mlx/compile.cpp
index 700c07ced..e54778fe1 100644
--- a/mlx/compile.cpp
+++ b/mlx/compile.cpp
@@ -47,6 +47,10 @@ bool is_binary(const Primitive& p) {
       typeid(p) == typeid(Subtract));
 }
 
+bool is_ternary(const Primitive& p) {
+  return typeid(p) == typeid(Select);
+}
+
 bool is_broadcast(const Primitive& p) {
   return typeid(p) == typeid(Broadcast);
 }
@@ -60,14 +64,16 @@ bool is_reduction(const Primitive& p) {
 }
 
 bool is_fusable(const Primitive& p) {
-  return is_unary(p) || is_binary(p) || is_broadcast(p) || is_noop(p);
+  return is_unary(p) || is_binary(p) || is_ternary(p) || is_broadcast(p) ||
+      is_noop(p);
 }
 
 bool allows_shapeless(const Primitive& p) {
   return typeid(p) == typeid(Compiled) || is_unary(p) || is_binary(p) ||
       is_noop(p) || is_reduction(p) || typeid(p) == typeid(Softmax) ||
       typeid(p) == typeid(Sort) || typeid(p) == typeid(ArgSort) ||
-      typeid(p) == typeid(ArgPartition) || typeid(p) == typeid(Partition);
+      typeid(p) == typeid(ArgPartition) || typeid(p) == typeid(Partition) ||
+      typeid(p) == typeid(Select);
 }
 
 Compiled::Compiled(
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 97d4a3a2d..a1d4c5b4e 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -1149,13 +1149,20 @@ array isneginf(const array& a, StreamOrDevice s /* = {} */) {
 }
 
 array where(
-    const array& condition,
-    const array& x,
-    const array& y,
+    const array& a,
+    const array& b,
+    const array& c,
     StreamOrDevice s /* = {} */) {
-  // TODO, fix this to handle the NaN case when x has infs
-  auto mask = astype(condition, bool_, s);
-  return add(multiply(x, mask, s), multiply(y, logical_not(mask, s), s), s);
+  auto condition = astype(a, bool_, s);
+  Dtype out_dtype = promote_types(b.dtype(), c.dtype());
+  auto inputs = broadcast_arrays(
+      {condition, astype(b, out_dtype, s), astype(c, out_dtype, s)}, s);
+
+  return array(
+      inputs[0].shape(),
+      out_dtype,
+      std::make_unique