Mirror of https://github.com/ml-explore/mlx.git (synced 2025-10-22 11:14:32 +08:00).
Commit: "Implement the 'where' primitive for conditional selection (#664)".
This commit contains the following changes:
@@ -27,6 +27,7 @@ set(
|
||||
"scan"
|
||||
"softmax"
|
||||
"sort"
|
||||
"ternary"
|
||||
"unary"
|
||||
"gather"
|
||||
"scatter"
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/metal/kernels/binary.h"
|
||||
#include "mlx/backend/metal/kernels/ternary.h"
|
||||
#include "mlx/backend/metal/kernels/unary.h"
|
||||
|
||||
typedef half float16_t;
|
||||
|
10
mlx/backend/metal/kernels/ternary.h
Normal file
10
mlx/backend/metal/kernels/ternary.h
Normal file
@@ -0,0 +1,10 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
// Element-wise ternary selection functor (the kernel-side op behind
// `where`): returns `x` where `condition` is true and `y` otherwise.
struct Select {
  template <typename T>
  // `const`: the functor carries no state and is invoked on temporaries
  // (`Op()(...)`), so the call operator should not claim mutation rights.
  T operator()(bool condition, T x, T y) const {
    return condition ? x : y;
  }
};
|
184
mlx/backend/metal/kernels/ternary.metal
Normal file
184
mlx/backend/metal/kernels/ternary.metal
Normal file
@@ -0,0 +1,184 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <metal_integer>
|
||||
#include <metal_math>
|
||||
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/ternary.h"
|
||||
|
||||
// General (strided) ternary kernel, 1-D case. `a` is the bool condition,
// `b`/`c` are the two value operands; the output `d` is written at the
// linear thread index.
template <typename T, typename Op>
[[kernel]] void ternary_op_g_nd1(
    device const bool* a,
    device const T* b,
    device const T* c,
    device T* d,
    constant const size_t& a_strides,
    constant const size_t& b_strides,
    constant const size_t& c_strides,
    uint index [[thread_position_in_grid]]) {
  // Each operand is gathered through its own stride.
  auto loc_a = elem_to_loc_1(index, a_strides);
  auto loc_b = elem_to_loc_1(index, b_strides);
  auto loc_c = elem_to_loc_1(index, c_strides);
  d[index] = Op()(a[loc_a], b[loc_b], c[loc_c]);
}
|
||||
|
||||
// General (strided) ternary kernel, 2-D case.
template <typename T, typename Op>
[[kernel]] void ternary_op_g_nd2(
    device const bool* a,
    device const T* b,
    device const T* c,
    device T* d,
    constant const size_t a_strides[2],
    constant const size_t b_strides[2],
    constant const size_t c_strides[2],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  // Row-major linear position in the (contiguous) output; widened to
  // size_t before the multiply.
  size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
  auto loc_a = elem_to_loc_2(index, a_strides);
  auto loc_b = elem_to_loc_2(index, b_strides);
  auto loc_c = elem_to_loc_2(index, c_strides);
  d[out_idx] = Op()(a[loc_a], b[loc_b], c[loc_c]);
}
|
||||
|
||||
// General (strided) ternary kernel, 3-D case.
template <typename T, typename Op>
[[kernel]] void ternary_op_g_nd3(
    device const bool* a,
    device const T* b,
    device const T* c,
    device T* d,
    constant const size_t a_strides[3],
    constant const size_t b_strides[3],
    constant const size_t c_strides[3],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  // Row-major linear position in the (contiguous) output; widened to
  // size_t before the multiplies.
  size_t out_idx =
      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
  auto loc_a = elem_to_loc_3(index, a_strides);
  auto loc_b = elem_to_loc_3(index, b_strides);
  auto loc_c = elem_to_loc_3(index, c_strides);
  d[out_idx] = Op()(a[loc_a], b[loc_b], c[loc_c]);
}
|
||||
|
||||
// General (strided) ternary kernel specialized on a compile-time rank DIM
// (instantiated for 4-D and 5-D below). `a` is the bool condition;
// `b`/`c` share the output element type T.
template <typename T, typename Op, int DIM>
[[kernel]] void ternary_op_g_nd(
    device const bool* a,
    device const T* b,
    device const T* c,
    device T* d,
    constant const int shape[DIM],
    constant const size_t a_strides[DIM],
    constant const size_t b_strides[DIM],
    constant const size_t c_strides[DIM],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  // Per-operand strided offsets: idx.x -> a, idx.y -> b, idx.z -> c.
  auto idx = elem_to_loc_3_nd<DIM>(index, shape, a_strides, b_strides, c_strides);
  // Output is written contiguously in row-major grid order.
  size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
  d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
}
|
||||
|
||||
// Fully general (strided) ternary kernel: rank `ndim` is a runtime value
// and each operand carries its own stride vector. `a` is the bool
// condition; `b`/`c` are the value operands.
template <typename T, typename Op>
[[kernel]] void ternary_op_g(
    device const bool* a,
    device const T* b,
    device const T* c,
    device T* d,
    constant const int* shape,
    constant const size_t* a_strides,
    constant const size_t* b_strides,
    constant const size_t* c_strides,
    constant const int& ndim,
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  // Per-operand strided offsets: idx.x -> a, idx.y -> b, idx.z -> c.
  auto idx = elem_to_loc_3_nd(index, shape, a_strides, b_strides, c_strides, ndim);
  // Widen to size_t before multiplying so large grids don't overflow
  // 32-bit arithmetic (matches ternary_op_g_nd2 / ternary_op_g_nd3,
  // which already cast).
  size_t out_idx =
      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
  d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
}
|
||||
|
||||
// Instantiate the runtime-rank general kernel under the given host name.
#define instantiate_ternary_g(name, type, op) \
  template [[host_name(name)]] \
  [[kernel]] void ternary_op_g<type, op>( \
      device const bool* a, \
      device const type* b, \
      device const type* c, \
      device type* d, \
      constant const int* shape, \
      constant const size_t* a_strides, \
      constant const size_t* b_strides, \
      constant const size_t* c_strides, \
      constant const int& ndim, \
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]); \

// Instantiate the compile-time-rank kernel under "name_<dims>".
#define instantiate_ternary_g_dim(name, type, op, dims) \
  template [[host_name(name "_" #dims)]] \
  [[kernel]] void ternary_op_g_nd<type, op, dims>( \
      device const bool* a, \
      device const type* b, \
      device const type* c, \
      device type* d, \
      constant const int shape[dims], \
      constant const size_t a_strides[dims], \
      constant const size_t b_strides[dims], \
      constant const size_t c_strides[dims], \
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]); \

// Instantiate the 1-D/2-D/3-D fast paths ("name_1".."name_3") plus the
// 4-D and 5-D compile-time-rank specializations.
#define instantiate_ternary_g_nd(name, type, op) \
  template [[host_name(name "_1")]] \
  [[kernel]] void ternary_op_g_nd1<type, op>( \
      device const bool* a, \
      device const type* b, \
      device const type* c, \
      device type* d, \
      constant const size_t& a_strides, \
      constant const size_t& b_strides, \
      constant const size_t& c_strides, \
      uint index [[thread_position_in_grid]]); \
  template [[host_name(name "_2")]] \
  [[kernel]] void ternary_op_g_nd2<type, op>( \
      device const bool* a, \
      device const type* b, \
      device const type* c, \
      device type* d, \
      constant const size_t a_strides[2], \
      constant const size_t b_strides[2], \
      constant const size_t c_strides[2], \
      uint2 index [[thread_position_in_grid]], \
      uint2 grid_dim [[threads_per_grid]]); \
  template [[host_name(name "_3")]] \
  [[kernel]] void ternary_op_g_nd3<type, op>( \
      device const bool* a, \
      device const type* b, \
      device const type* c, \
      device type* d, \
      constant const size_t a_strides[3], \
      constant const size_t b_strides[3], \
      constant const size_t c_strides[3], \
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]); \
  instantiate_ternary_g_dim(name, type, op, 4) \
  instantiate_ternary_g_dim(name, type, op, 5) \

// All general variants for one (op, dtype) pair; host names are
// "g<op><dtype>" and "g<op><dtype>_<N>".
#define instantiate_ternary_all(name, tname, type, op) \
  instantiate_ternary_g("g" #name #tname, type, op) \
  instantiate_ternary_g_nd("g" #name #tname, type, op) \

// Floating-point dtypes: half, float, bfloat16.
#define instantiate_ternary_float(name, op) \
  instantiate_ternary_all(name, float16, half, op) \
  instantiate_ternary_all(name, float32, float, op) \
  instantiate_ternary_all(name, bfloat16, bfloat16_t, op)

// Every supported dtype (bool, all int widths, complex, floats).
#define instantiate_ternary_types(name, op) \
  instantiate_ternary_all(name, bool_, bool, op) \
  instantiate_ternary_all(name, uint8, uint8_t, op) \
  instantiate_ternary_all(name, uint16, uint16_t, op) \
  instantiate_ternary_all(name, uint32, uint32_t, op) \
  instantiate_ternary_all(name, uint64, uint64_t, op) \
  instantiate_ternary_all(name, int8, int8_t, op) \
  instantiate_ternary_all(name, int16, int16_t, op) \
  instantiate_ternary_all(name, int32, int32_t, op) \
  instantiate_ternary_all(name, int64, int64_t, op) \
  instantiate_ternary_all(name, complex64, complex64_t, op) \
  instantiate_ternary_float(name, op)

// Emit the full set of "select" kernels using the Select functor.
instantiate_ternary_types(select, Select)
|
@@ -91,6 +91,30 @@ inline size_t elem_to_loc(
|
||||
return loc;
|
||||
}
|
||||
|
||||
// Map one 3-D grid position to three strided element offsets at once
// (loc.x -> a, loc.y -> b, loc.z -> c). Grid x/y index the innermost two
// dims directly; grid z is unraveled across the remaining leading dims.
// NOTE(review): indexes strides[NDIM - 2], so this assumes NDIM >= 2 —
// confirm callers never instantiate a smaller rank.
template <int NDIM>
inline uint3 elem_to_loc_3_nd(
    uint3 elem,
    constant const int shape[NDIM],
    constant const size_t a_strides[NDIM],
    constant const size_t b_strides[NDIM],
    constant const size_t c_strides[NDIM]) {
  uint3 loc;
  loc.x = static_cast<uint>(
      elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]);
  loc.y = static_cast<uint>(
      elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2]);
  loc.z = static_cast<uint>(
      elem.x * c_strides[NDIM - 1] + elem.y * c_strides[NDIM - 2]);
  // Peel off one leading dimension per iteration from the flattened z
  // coordinate, innermost of the remaining dims first.
  for (int axis = NDIM - 3; axis >= 0; --axis) {
    uint pos = elem.z % shape[axis];
    loc.x += pos * a_strides[axis];
    loc.y += pos * b_strides[axis];
    loc.z += pos * c_strides[axis];
    elem.z /= shape[axis];
  }
  return loc;
}
|
||||
|
||||
template <int NDIM>
|
||||
inline uint2 elem_to_loc_2_nd(
|
||||
uint3 elem,
|
||||
@@ -150,6 +174,30 @@ inline size_t elem_to_loc(
|
||||
return loc;
|
||||
}
|
||||
|
||||
// Runtime-rank overload of elem_to_loc_3_nd: same mapping as the
// templated version but with `ndim` supplied at dispatch time.
// NOTE(review): reads strides[ndim - 2], so this assumes ndim >= 2 —
// confirm against the host-side dispatch.
inline uint3 elem_to_loc_3_nd(
    uint3 elem,
    constant const int* shape,
    constant const size_t* a_strides,
    constant const size_t* b_strides,
    constant const size_t* c_strides,
    int ndim) {
  uint3 loc;
  loc.x = static_cast<uint>(
      elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]);
  loc.y = static_cast<uint>(
      elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2]);
  loc.z = static_cast<uint>(
      elem.x * c_strides[ndim - 1] + elem.y * c_strides[ndim - 2]);
  // Unravel the flattened z coordinate across the leading dimensions.
  for (int axis = ndim - 3; axis >= 0; --axis) {
    uint pos = elem.z % shape[axis];
    loc.x += pos * a_strides[axis];
    loc.y += pos * b_strides[axis];
    loc.z += pos * c_strides[axis];
    elem.z /= shape[axis];
  }
  return loc;
}
|
||||
|
||||
inline uint2 elem_to_loc_2_nd(
|
||||
uint3 elem,
|
||||
constant const int* shape,
|
||||
|
@@ -6,6 +6,7 @@
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/common/binary.h"
|
||||
#include "mlx/backend/common/ternary.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
@@ -43,24 +44,25 @@ void binary_op(
|
||||
|
||||
std::ostringstream kname;
|
||||
switch (bopt) {
|
||||
case ScalarScalar:
|
||||
case BinaryOpType::ScalarScalar:
|
||||
kname << "ss";
|
||||
break;
|
||||
case ScalarVector:
|
||||
case BinaryOpType::ScalarVector:
|
||||
kname << "sv";
|
||||
break;
|
||||
case VectorScalar:
|
||||
case BinaryOpType::VectorScalar:
|
||||
kname << "vs";
|
||||
break;
|
||||
case VectorVector:
|
||||
case BinaryOpType::VectorVector:
|
||||
kname << "vv";
|
||||
break;
|
||||
case General:
|
||||
case BinaryOpType::General:
|
||||
kname << "g";
|
||||
break;
|
||||
}
|
||||
kname << op << type_to_name(a);
|
||||
if (bopt == General && shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
|
||||
if (bopt == BinaryOpType::General &&
|
||||
shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
|
||||
kname << "_" << shape.size();
|
||||
}
|
||||
|
||||
@@ -80,7 +82,7 @@ void binary_op(
|
||||
set_array_buffer(compute_encoder, outputs[0], 2);
|
||||
set_array_buffer(compute_encoder, outputs[1], 3);
|
||||
|
||||
if (bopt == General) {
|
||||
if (bopt == BinaryOpType::General) {
|
||||
auto ndim = shape.size();
|
||||
if (ndim > 3) {
|
||||
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
|
||||
@@ -141,24 +143,25 @@ void binary_op(
|
||||
|
||||
std::ostringstream kname;
|
||||
switch (bopt) {
|
||||
case ScalarScalar:
|
||||
case BinaryOpType::ScalarScalar:
|
||||
kname << "ss";
|
||||
break;
|
||||
case ScalarVector:
|
||||
case BinaryOpType::ScalarVector:
|
||||
kname << "sv";
|
||||
break;
|
||||
case VectorScalar:
|
||||
case BinaryOpType::VectorScalar:
|
||||
kname << "vs";
|
||||
break;
|
||||
case VectorVector:
|
||||
case BinaryOpType::VectorVector:
|
||||
kname << "vv";
|
||||
break;
|
||||
case General:
|
||||
case BinaryOpType::General:
|
||||
kname << "g";
|
||||
break;
|
||||
}
|
||||
kname << op << type_to_name(a);
|
||||
if (bopt == General && shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
|
||||
if (bopt == BinaryOpType::General &&
|
||||
shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
|
||||
kname << "_" << shape.size();
|
||||
}
|
||||
|
||||
@@ -173,7 +176,7 @@ void binary_op(
|
||||
set_array_buffer(compute_encoder, donate_b ? out : b, 1);
|
||||
set_array_buffer(compute_encoder, out, 2);
|
||||
|
||||
if (bopt == General) {
|
||||
if (bopt == BinaryOpType::General) {
|
||||
auto ndim = shape.size();
|
||||
if (ndim > 3) {
|
||||
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 3);
|
||||
@@ -202,7 +205,8 @@ void binary_op(
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
} else {
|
||||
// Launch a 1D grid of threads
|
||||
size_t nthreads = bopt == General ? out.size() : out.data_size();
|
||||
size_t nthreads =
|
||||
bopt == BinaryOpType::General ? out.size() : out.data_size();
|
||||
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size > nthreads) {
|
||||
@@ -213,6 +217,86 @@ void binary_op(
|
||||
}
|
||||
}
|
||||
|
||||
void ternary_op(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string op) {
|
||||
assert(inputs.size() == 3);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto& c = inputs[2];
|
||||
TernaryOpType topt = get_ternary_op_type(a, b, c);
|
||||
set_ternary_op_output_data(a, b, c, out, topt, true /* donate_with_move */);
|
||||
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Try to collapse contiguous dims
|
||||
auto [shape, strides] = collapse_contiguous_dims(a, b, c, out);
|
||||
auto& strides_a = strides[0];
|
||||
auto& strides_b = strides[1];
|
||||
auto& strides_c = strides[2];
|
||||
auto& strides_out = strides[3];
|
||||
|
||||
std::ostringstream kname;
|
||||
kname << "g";
|
||||
kname << op << type_to_name(b);
|
||||
if (topt == TernaryOpType::General &&
|
||||
shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
|
||||
kname << "_" << shape.size();
|
||||
}
|
||||
|
||||
auto& s = out.primitive().stream();
|
||||
auto& d = metal::device(s.device);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
set_array_buffer(compute_encoder, a, 0);
|
||||
set_array_buffer(compute_encoder, b, 1);
|
||||
set_array_buffer(compute_encoder, c, 2);
|
||||
set_array_buffer(compute_encoder, out, 3);
|
||||
|
||||
auto ndim = shape.size();
|
||||
if (ndim > 3) {
|
||||
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
|
||||
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5);
|
||||
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6);
|
||||
compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 7);
|
||||
|
||||
if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
|
||||
compute_encoder->setBytes(&ndim, sizeof(int), 8);
|
||||
}
|
||||
} else if (ndim > 0) {
|
||||
// The shape is implicit in the grid for <= 3D
|
||||
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
|
||||
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
|
||||
compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 6);
|
||||
} else {
|
||||
// For 0-dim we still need to bind something to these buffers since the
|
||||
// current ternary kernels always access the strides.
|
||||
size_t dummy_stride = 0;
|
||||
int dummy_shape = 0;
|
||||
compute_encoder->setBytes(&dummy_shape, sizeof(int), 4);
|
||||
compute_encoder->setBytes(&dummy_stride, sizeof(size_t), 5);
|
||||
compute_encoder->setBytes(&dummy_stride, sizeof(size_t), 6);
|
||||
compute_encoder->setBytes(&dummy_stride, sizeof(size_t), 7);
|
||||
compute_encoder->setBytes(&ndim, sizeof(int), 8);
|
||||
}
|
||||
|
||||
// Launch up to 3D grid of threads
|
||||
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
||||
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
|
||||
size_t rest = out.size() / (dim0 * dim1);
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size != 1024) {
|
||||
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
|
||||
}
|
||||
MTL::Size group_dims = get_block_dims(dim0, dim1, rest);
|
||||
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void unary_op(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
@@ -619,6 +703,10 @@ void Multiply::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "mul");
|
||||
}
|
||||
|
||||
// GPU evaluation of Select (`where`): forwards {condition, on_true,
// on_false} to the Metal "select" ternary kernel dispatcher.
void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
  ternary_op(inputs, out, "select");
}
|
||||
|
||||
// GPU evaluation of Negative: forwards to the Metal "neg" unary kernel
// dispatcher.
void Negative::eval_gpu(const std::vector<array>& inputs, array& out) {
  unary_op(inputs, out, "neg");
}
|
||||
|
Reference in New Issue
Block a user