More fixes for arrays with large sizes (#1405)

* compile works for big arrays when contiguous

* style

* nits in docs

* a bunch more stuff

* update jit

* update jit

* use constant for shapes and strides and remove elem_to_loc overload

* use kernel instantiation

* docs nits

* update binary and ternary

* comments
This commit is contained in:
Awni Hannun 2024-09-17 12:46:31 -07:00 committed by GitHub
parent c6739ba7f3
commit 4f46e9c997
26 changed files with 325 additions and 611 deletions

View File

@@ -13,6 +13,7 @@ simple functions.
    :template: nn-module-template.rst

    elu
+   celu
    gelu
    gelu_approx
    gelu_fast_approx

View File

@@ -13,6 +13,7 @@ Layers
    AvgPool1d
    AvgPool2d
    BatchNorm
+   CELU
    Conv1d
    Conv2d
    Conv3d
@@ -23,6 +24,7 @@ Layers
    Dropout2d
    Dropout3d
    Embedding
+   ELU
    GELU
    GLU
    GroupNorm
@@ -34,6 +36,8 @@ Layers
    LayerNorm
    LeakyReLU
    Linear
+   LogSigmoid
+   LogSoftmax
    LSTM
    MaxPool1d
    MaxPool2d
@@ -49,6 +53,7 @@ Layers
    RoPE
    SELU
    Sequential
+   Sigmoid
    SiLU
    SinusoidalPositionalEncoding
    Softmin

View File

@@ -51,6 +51,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
   ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp)

 if(IOS)

View File

@@ -0,0 +1,88 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/backend/common/utils.h"

namespace mlx::core {

template <typename stride_t>
std::tuple<std::vector<int>, std::vector<std::vector<stride_t>>>
collapse_contiguous_dims_impl(
    const std::vector<int>& shape,
    const std::vector<std::vector<stride_t>>& strides,
    stride_t size_cap) {
  // Make a vector that has axes separated with -1. Collapse all axes between
  // -1.
  std::vector<int> to_collapse;
  if (shape.size() > 0) {
    if (shape[0] != 1) {
      to_collapse.push_back(0);
    }
    size_t size = shape[0];
    for (int i = 1; i < shape.size(); i++) {
      bool contiguous = true;
      size *= shape[i];
      for (const std::vector<stride_t>& st : strides) {
        if (st[i] * shape[i] != st[i - 1] || size > size_cap) {
          contiguous = false;
          size = shape[i];
          break;
        }
      }
      if (!contiguous) {
        to_collapse.push_back(-1);
      }
      if (shape[i] != 1) {
        to_collapse.push_back(i);
      }
    }
    to_collapse.push_back(-1);
  }

  std::vector<int> out_shape;
  std::vector<std::vector<stride_t>> out_strides(strides.size());
  for (int i = 0;;) {
    while (i < to_collapse.size() && to_collapse[i] == -1) {
      ++i;
    }
    if (i == to_collapse.size()) {
      break;
    }
    int current_shape = shape[to_collapse[i]];
    int k = i;
    while (to_collapse[++k] != -1) {
      current_shape *= shape[to_collapse[k]];
    }
    out_shape.push_back(current_shape);
    for (int j = 0; j < strides.size(); j++) {
      const std::vector<stride_t>& st = strides[j];
      out_strides[j].push_back(st[to_collapse[k - 1]]);
    }
    i = k + 1;
  }

  if (!shape.empty() && out_shape.empty()) {
    out_shape.push_back(1);
    for (auto& out_stride : out_strides) {
      out_stride.push_back(0);
    }
  }
  return std::make_tuple(out_shape, out_strides);
}

std::tuple<std::vector<int>, std::vector<std::vector<int64_t>>>
collapse_contiguous_dims(
    const std::vector<int>& shape,
    const std::vector<std::vector<int64_t>>& strides,
    int64_t size_cap /* = std::numeric_limits<int32_t>::max() */) {
  return collapse_contiguous_dims_impl(shape, strides, size_cap);
}

std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(
    const std::vector<int>& shape,
    const std::vector<std::vector<size_t>>& strides,
    size_t size_cap /* = std::numeric_limits<int32_t>::max() */) {
  return collapse_contiguous_dims_impl(shape, strides, size_cap);
}

} // namespace mlx::core
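
Aside: the size cap is easiest to see on a concrete shape. A minimal usage sketch, not part of the diff; it assumes an MLX checkout with this header on the include path, and the shape and strides below are made up for illustration:

// Sketch: a row-contiguous (4, 1024, 1024, 1024) array has 2^32 elements.
// Without a cap it would collapse to a single axis whose size no longer fits
// in the int shape entries the kernels use; with the default INT32_MAX cap,
// the collapse is split so every collapsed axis stays 32-bit addressable.
#include <cassert>
#include <cstdint>
#include <vector>

#include "mlx/backend/common/utils.h"

int main() {
  std::vector<int> shape = {4, 1024, 1024, 1024};
  std::vector<std::vector<int64_t>> strides = {
      {1073741824, 1048576, 1024, 1}};
  auto [out_shape, out_strides] =
      mlx::core::collapse_contiguous_dims(shape, strides);
  // Here the result is {4194304, 1024}: two axes instead of one overflowing
  // axis, each below the cap.
  for (auto s : out_shape) {
    assert(s <= INT32_MAX);
  }
}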

View File

@@ -44,58 +44,26 @@ std::vector<stride_t> make_contiguous_strides(const std::vector<int>& shape) {
 //
 // When multiple arrays are passed they should all have the same shape. The
 // collapsed axes are also the same so one shape is returned.
-template <typename stride_t>
-inline std::tuple<std::vector<int>, std::vector<std::vector<stride_t>>>
-collapse_contiguous_dims(
-    const std::vector<int>& shape,
-    const std::vector<std::vector<stride_t>> strides) {
-  // Make a vector that has axes separated with -1. Collapse all axes between
-  // -1.
-  std::vector<int> to_collapse;
-  if (shape.size() > 0) {
-    to_collapse.push_back(0);
-    for (int i = 1; i < shape.size(); i++) {
-      bool contiguous = true;
-      for (const std::vector<stride_t>& st : strides) {
-        if (st[i] * shape[i] != st[i - 1]) {
-          contiguous = false;
-        }
-        if (!contiguous) {
-          break;
-        }
-      }
-      if (!contiguous) {
-        to_collapse.push_back(-1);
-      }
-      to_collapse.push_back(i);
-    }
-    to_collapse.push_back(-1);
-  }
-
-  std::vector<int> out_shape;
-  std::vector<std::vector<stride_t>> out_strides(strides.size());
-  for (int i = 0; i < to_collapse.size(); i++) {
-    int current_shape = shape[to_collapse[i]];
-    while (to_collapse[++i] != -1) {
-      current_shape *= shape[to_collapse[i]];
-    }
-    out_shape.push_back(current_shape);
-    for (int j = 0; j < strides.size(); j++) {
-      const std::vector<stride_t>& st = strides[j];
-      out_strides[j].push_back(st[to_collapse[i - 1]]);
-    }
-  }
-
-  return std::make_tuple(out_shape, out_strides);
-}
+std::tuple<std::vector<int>, std::vector<std::vector<int64_t>>>
+collapse_contiguous_dims(
+    const std::vector<int>& shape,
+    const std::vector<std::vector<int64_t>>& strides,
+    int64_t size_cap = std::numeric_limits<int32_t>::max());
+
+std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
+collapse_contiguous_dims(
+    const std::vector<int>& shape,
+    const std::vector<std::vector<size_t>>& strides,
+    size_t size_cap = std::numeric_limits<int32_t>::max());

 inline std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
-collapse_contiguous_dims(const std::vector<array>& xs) {
+collapse_contiguous_dims(
+    const std::vector<array>& xs,
+    size_t size_cap = std::numeric_limits<size_t>::max()) {
   std::vector<std::vector<size_t>> strides;
   for (auto& x : xs) {
     strides.emplace_back(x.strides());
   }
-  return collapse_contiguous_dims(xs[0].shape(), strides);
+  return collapse_contiguous_dims(xs[0].shape(), strides, size_cap);
 }

 template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>

View File

@@ -19,7 +19,7 @@
 namespace mlx::core {

-constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;
+constexpr int MAX_BINARY_SPECIALIZED_DIMS = 3;

 std::string get_kernel_name(
     BinaryOpType bopt,
@@ -69,46 +69,61 @@ void binary_op_gpu_inplace(
   }

   // Try to collapse contiguous dims
-  auto [shape, strides] = collapse_contiguous_dims(a, b, out);
-  auto& strides_a = strides[0];
-  auto& strides_b = strides[1];
-  auto& strides_out = strides[2];
+  auto maybe_collapse = [bopt, &a, &b, &out]() {
+    if (bopt == BinaryOpType::General) {
+      // The size cap here should ideally be `UINT32_MAX` but we are
+      // limited by the shape being an int.
+      auto [shape, strides] = collapse_contiguous_dims(
+          {a, b, out},
+          /* size_cap = */ INT32_MAX);
+      return std::make_tuple(shape, strides[0], strides[1], strides[2]);
+    } else {
+      std::vector<size_t> e;
+      return std::make_tuple(std::vector<int>{}, e, e, e);
+    }
+  };
+  auto [shape, strides_a, strides_b, strides_out] = maybe_collapse();

   bool use_2d = out.data_size() > UINT32_MAX;
   std::string kernel_name = get_kernel_name(bopt, op, a, use_2d, shape.size());
   auto& d = metal::device(s.device);

-  auto kernel =
-      get_binary_two_kernel(d, kernel_name, a.dtype(), outputs[0].dtype(), op);
+  auto kernel = outputs.size() == 2
+      ? get_binary_two_kernel(d, kernel_name, a.dtype(), out.dtype(), op)
+      : get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op);
   auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder->setComputePipelineState(kernel);

   // - If a is donated it goes to the first output
   // - If b is donated it goes to the first output if a was not donated
-  //   otherwise it goes to the second output
+  //   otherwise it goes to the second output.
+  // - If there is only one output only one of a and b will be donated.
   bool donate_a = a.data_shared_ptr() == nullptr;
   bool donate_b = b.data_shared_ptr() == nullptr;
-  compute_encoder.set_input_array(donate_a ? outputs[0] : a, 0);
+  int arg_idx = 0;
+  compute_encoder.set_input_array(donate_a ? outputs[0] : a, arg_idx++);
   compute_encoder.set_input_array(
-      donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, 1);
-  compute_encoder.set_output_array(outputs[0], 2);
-  compute_encoder.set_output_array(outputs[1], 3);
+      donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, arg_idx++);
+  compute_encoder.set_output_array(outputs[0], arg_idx++);
+  if (outputs.size() == 2) {
+    compute_encoder.set_output_array(outputs[1], arg_idx++);
+  }

   if (bopt == BinaryOpType::General) {
     auto ndim = shape.size();
     if (ndim > 3) {
-      compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
-      compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5);
-      compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6);
+      compute_encoder->setBytes(shape.data(), ndim * sizeof(int), arg_idx++);
+      compute_encoder->setBytes(
+          strides_a.data(), ndim * sizeof(size_t), arg_idx++);
+      compute_encoder->setBytes(
+          strides_b.data(), ndim * sizeof(size_t), arg_idx++);
+      compute_encoder->setBytes(&ndim, sizeof(int), arg_idx++);
     } else {
       // The shape is implicit in the grid for <= 3D
-      compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
-      compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
-    }
-
-    if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
-      compute_encoder->setBytes(&ndim, sizeof(int), 7);
+      compute_encoder->setBytes(
+          strides_a.data(), ndim * sizeof(size_t), arg_idx++);
+      compute_encoder->setBytes(
+          strides_b.data(), ndim * sizeof(size_t), arg_idx++);
     }

     // Launch up to 3D grid of threads
@@ -125,9 +140,8 @@ void binary_op_gpu_inplace(
   } else {
     // Launch a 1D or 2D grid of threads
     size_t nthreads = out.data_size();
-    MTL::Size grid_dims = use_2d
-        ? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
-        : MTL::Size(nthreads, 1, 1);
+    MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
+                                 : MTL::Size(nthreads, 1, 1);
     NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
     if (thread_group_size > nthreads) {
       thread_group_size = nthreads;
@@ -164,72 +178,8 @@ void binary_op_gpu_inplace(
     array& out,
     const std::string& op,
     const Stream& s) {
-  auto& a = inputs[0];
-  auto& b = inputs[1];
-  auto bopt = get_binary_op_type(a, b);
-  if (out.size() == 0) {
-    return;
-  }
-
-  // Try to collapse contiguous dims
-  auto [shape, strides] = collapse_contiguous_dims(a, b, out);
-  auto& strides_a = strides[0];
-  auto& strides_b = strides[1];
-  auto& strides_out = strides[2];
-
-  bool use_2d = out.data_size() > UINT32_MAX;
-  std::string kernel_name = get_kernel_name(bopt, op, a, use_2d, shape.size());
-  auto& d = metal::device(s.device);
-
-  auto kernel = get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op);
-  auto& compute_encoder = d.get_command_encoder(s.index);
-  compute_encoder->setComputePipelineState(kernel);
-  bool donate_a = a.data_shared_ptr() == nullptr;
-  bool donate_b = b.data_shared_ptr() == nullptr;
-  compute_encoder.set_input_array(donate_a ? out : a, 0);
-  compute_encoder.set_input_array(donate_b ? out : b, 1);
-  compute_encoder.set_output_array(out, 2);
-
-  if (bopt == BinaryOpType::General) {
-    auto ndim = shape.size();
-    if (ndim > 3) {
-      compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 3);
-      compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
-      compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
-    } else {
-      // The shape is implicit in the grid for <= 3D
-      compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 3);
-      compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 4);
-    }
-
-    if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
-      compute_encoder->setBytes(&ndim, sizeof(int), 6);
-    }
-
-    // Launch up to 3D grid of threads
-    size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
-    size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
-    size_t rest = out.size() / (dim0 * dim1);
-    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
-    if (thread_group_size != 1024) {
-      throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
-    }
-    auto group_dims = get_block_dims(dim0, dim1, rest);
-    MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
-    compute_encoder.dispatchThreads(grid_dims, group_dims);
-  } else {
-    // Launch a 1D or 2D grid of threads
-    size_t nthreads = out.data_size();
-    MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
-                                 : MTL::Size(nthreads, 1, 1);
-    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
-    if (thread_group_size > nthreads) {
-      thread_group_size = nthreads;
-    }
-    MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-    compute_encoder.dispatchThreads(grid_dims, group_dims);
-  }
+  std::vector<array> outputs = {out};
+  binary_op_gpu_inplace(inputs, outputs, op, s);
 }

 void binary_op_gpu(

View File

@@ -22,7 +22,8 @@ inline void build_kernel(
     const std::unordered_set<uintptr_t>& constant_ids,
     bool contiguous,
     int ndim,
-    bool dynamic_dims) {
+    bool dynamic_dims,
+    bool use_big_index = false) {
   // All outputs should have the exact same shape and will be row contiguous
   auto output_shape = outputs[0].shape();
   auto output_strides = outputs[0].strides();
@@ -84,9 +85,15 @@ inline void build_kernel(
   // The thread index in the whole grid
   os << "  uint3 pos [[thread_position_in_grid]]," << std::endl
-     << "  uint3 grid [[threads_per_grid]]) {" << std::endl
-     << "  uint index = pos.x + grid.x * (pos.y + grid.y * pos.z);"
-     << std::endl;
+     << "  uint3 grid [[threads_per_grid]]) {" << std::endl;
+  if (use_big_index) {
+    // This is only used for contiguous kernels which don't have
+    // a third grid dimension
+    os << "  size_t index = pos.x + grid.x * size_t(pos.y);";
+  } else {
+    os << "  uint index = pos.x + grid.x * (pos.y + grid.y * pos.z);";
+  }
+  os << std::endl;

   // Extract the indices per axis to individual uints if we have arrays that
   // are broadcasted or transposed
@@ -212,6 +219,17 @@ void Compiled::eval_gpu(
         /* contiguous = */ true,
         /* ndim = */ 0,
         /* dynamic_dims = */ false);
+    build_kernel(
+        kernel,
+        kernel_lib_ + "_contiguous_big",
+        inputs_,
+        outputs_,
+        tape_,
+        constant_ids_,
+        /* contiguous = */ true,
+        /* ndim = */ 0,
+        /* dynamic_dims = */ false,
+        /* use_big_index = */ true);
     for (int i = 1; i < 8; i++) {
       build_kernel(
           kernel,
@@ -285,7 +303,16 @@ void Compiled::eval_gpu(
       initial_strides.push_back(std::move(xstrides));
     }
     std::tie(shape, strides) =
-        collapse_contiguous_dims(output_shape, initial_strides);
+        collapse_contiguous_dims(output_shape, initial_strides, INT32_MAX);
+  }
+
+  bool use_2d = false;
+  if (contiguous) {
+    size_t max_size = 0;
+    for (auto& in : inputs) {
+      max_size = std::max(max_size, in.data_size());
+    }
+    use_2d = (max_size > UINT32_MAX);
   }

   // Get the kernel from the lib
@@ -298,6 +325,8 @@ void Compiled::eval_gpu(
     } else {
       kernel_name += std::to_string(shape.size());
     }
+  } else if (use_2d) {
+    kernel_name += "_big";
   }
   auto kernel = d.get_kernel(kernel_name, lib);
   auto& compute_encoder = d.get_command_encoder(s.index);
@@ -348,8 +377,10 @@ void Compiled::eval_gpu(
   // Launch the kernel
   if (contiguous) {
-    size_t nthreads = outputs[0].size();
-    MTL::Size grid_dims(nthreads, 1, 1);
+    size_t nthreads = outputs[0].data_size();
+    MTL::Size grid_dims = use_2d
+        ? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
+        : MTL::Size(nthreads, 1, 1);
     MTL::Size group_dims(
         std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
     compute_encoder.dispatchThreads(grid_dims, group_dims);

View File

@@ -10,7 +10,7 @@
 namespace mlx::core {

-constexpr int MAX_COPY_SPECIALIZED_DIMS = 5;
+constexpr int MAX_COPY_SPECIALIZED_DIMS = 3;

 void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
   if (ctype == CopyType::Vector) {
@@ -59,10 +59,20 @@ void copy_gpu_inplace(
   }

   // Try to collapse contiguous dims
-  auto [shape, strides] = collapse_contiguous_dims(
-      data_shape, std::vector{strides_in_pre, strides_out_pre});
-  auto& strides_in_ = strides[0];
-  auto& strides_out_ = strides[1];
+  auto maybe_collapse =
+      [ctype, &data_shape, &strides_in_pre, &strides_out_pre]() {
+        if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
+          auto [shape, strides] = collapse_contiguous_dims(
+              data_shape,
+              std::vector{strides_in_pre, strides_out_pre},
+              /* size_cap = */ INT32_MAX);
+          return std::make_tuple(shape, strides[0], strides[1]);
+        } else {
+          std::vector<stride_t> e;
+          return std::make_tuple(std::vector<int>{}, e, e);
+        }
+      };
+  auto [shape, strides_in_, strides_out_] = maybe_collapse();

   bool use_2d = out.data_size() > UINT32_MAX;
   auto& d = metal::device(s.device);

View File

@@ -1,100 +0,0 @@
// Copyright © 2024 Apple Inc.

constexpr std::string_view copy_kernels = R"(
template [[host_name("s_{0}")]] [[kernel]] void copy_s<{1}, {2}>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    uint index [[thread_position_in_grid]]);
template [[host_name("v_{0}")]] [[kernel]] void copy_v<{1}, {2}>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    uint index [[thread_position_in_grid]]);
template [[host_name("g4_{0}")]] [[kernel]] void
copy_g_nd<{1}, {2}, 4>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg4_{0}")]] [[kernel]] void
copy_gg_nd<{1}, {2}, 4>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    uint3 index [[thread_position_in_grid]]);
template [[host_name("g5_{0}")]] [[kernel]] void
copy_g_nd<{1}, {2}, 5>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg5_{0}")]] [[kernel]] void
copy_gg_nd<{1}, {2}, 5>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    uint3 index [[thread_position_in_grid]]);
template [[host_name("g1_{0}")]] [[kernel]] void copy_g_nd1<{1}, {2}>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int64_t& src_stride [[buffer(3)]],
    uint index [[thread_position_in_grid]]);
template [[host_name("g2_{0}")]] [[kernel]] void copy_g_nd2<{1}, {2}>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]);
template [[host_name("g3_{0}")]] [[kernel]] void copy_g_nd3<{1}, {2}>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg1_{0}")]] [[kernel]] void
copy_gg_nd1<{1}, {2}>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int64_t& src_stride [[buffer(3)]],
    constant const int64_t& dst_stride [[buffer(4)]],
    uint index [[thread_position_in_grid]]);
template [[host_name("gg2_{0}")]] [[kernel]] void
copy_gg_nd2<{1}, {2}>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    uint2 index [[thread_position_in_grid]]);
template [[host_name("gg3_{0}")]] [[kernel]] void
copy_gg_nd3<{1}, {2}>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    uint3 index [[thread_position_in_grid]]);
template [[host_name("g_{0}")]] [[kernel]] void copy_g<{1}, {2}>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int& ndim [[buffer(5)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg_{0}")]] [[kernel]] void copy_gg<{1}, {2}>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    constant const int& ndim [[buffer(5)]],
    uint3 index [[thread_position_in_grid]]);
)";

View File

@@ -1,9 +1,7 @@
 // Copyright © 2024 Apple Inc.

-#include <map>
 #include "mlx/backend/common/compiled.h"
 #include "mlx/backend/metal/jit/arange.h"
-#include "mlx/backend/metal/jit/copy.h"
 #include "mlx/backend/metal/jit/gemv_masked.h"
 #include "mlx/backend/metal/jit/includes.h"
 #include "mlx/backend/metal/jit/scan.h"
@@ -67,7 +65,7 @@ void add_binary_kernels(
     Dtype out_type,
     const std::string op,
     std::ostringstream& kernel_source) {
-  const std::map<std::string, std::string> kernel_types = {
+  const std::array<std::pair<std::string, std::string>, 11> kernel_types = {{
       {"ss", "binary_ss"},
       {"vs", "binary_vs"},
       {"sv", "binary_sv"},
@@ -78,29 +76,16 @@ void add_binary_kernels(
       {"g1", "binary_g_nd1"},
       {"g2", "binary_g_nd2"},
       {"g3", "binary_g_nd3"},
-      {"g4", "binary_g_nd"},
-      {"g5", "binary_g_nd"},
       {"gn", "binary_g"},
-  };
-  for (auto [name, func] : kernel_types) {
+  }};
+  for (auto& [name, func] : kernel_types) {
     std::string template_def;
-    if (name == "g4" || name == "g5") {
-      int dim = std::stoi(name.substr(1));
-      template_def = get_template_definition(
-          name + lib_name,
-          func,
-          get_type_string(in_type),
-          get_type_string(out_type),
-          op,
-          dim);
-    } else {
-      template_def = get_template_definition(
-          name + lib_name,
-          func,
-          get_type_string(in_type),
-          get_type_string(out_type),
-          op);
-    }
+    template_def = get_template_definition(
+        name + lib_name,
+        func,
+        get_type_string(in_type),
+        get_type_string(out_type),
+        op);
     kernel_source << template_def;
   }
 }
@@ -149,27 +134,19 @@ MTL::ComputePipelineState* get_ternary_kernel(
   auto lib = d.get_library(lib_name);
   if (lib == nullptr) {
     std::ostringstream kernel_source;
-    const std::map<std::string, std::string> kernel_types = {
+    const std::array<std::pair<std::string, std::string>, 6> kernel_types = {{
         {"v", "ternary_v"},
         {"v2", "ternary_v2"},
        {"g", "ternary_g"},
        {"g1", "ternary_g_nd1"},
        {"g2", "ternary_g_nd2"},
        {"g3", "ternary_g_nd3"},
-        {"g4", "ternary_g_nd"},
-        {"g5", "ternary_g_nd"},
-    };
+    }};
     kernel_source << metal::utils() << metal::ternary_ops() << metal::ternary();
-    for (auto [name, func] : kernel_types) {
+    for (auto& [name, func] : kernel_types) {
       std::string template_def;
-      if (name == "g4" || name == "g5") {
-        int dim = std::stoi(name.substr(1));
-        template_def = get_template_definition(
-            name + "_" + lib_name, func, get_type_string(type), op, dim);
-      } else {
-        template_def = get_template_definition(
-            name + "_" + lib_name, func, get_type_string(type), op);
-      }
+      template_def = get_template_definition(
+          name + "_" + lib_name, func, get_type_string(type), op);
       kernel_source << template_def;
     }
     lib = d.get_library(lib_name, kernel_source.str());
@@ -186,12 +163,27 @@ MTL::ComputePipelineState* get_copy_kernel(
   auto lib = d.get_library(lib_name);
   if (lib == nullptr) {
     std::ostringstream kernel_source;
-    kernel_source << metal::utils() << metal::copy()
-                  << fmt::format(
-                         copy_kernels,
-                         lib_name,
-                         get_type_string(in.dtype()),
-                         get_type_string(out.dtype()));
+    auto in_type = get_type_string(in.dtype());
+    auto out_type = get_type_string(out.dtype());
+    kernel_source
+        << metal::utils() << metal::copy()
+        << get_template_definition("s_" + lib_name, "copy_s", in_type, out_type)
+        << get_template_definition("v_" + lib_name, "copy_v", in_type, out_type)
+        << get_template_definition(
+               "g1_" + lib_name, "copy_g_nd1", in_type, out_type)
+        << get_template_definition(
+               "g2_" + lib_name, "copy_g_nd2", in_type, out_type)
+        << get_template_definition(
+               "g3_" + lib_name, "copy_g_nd3", in_type, out_type)
+        << get_template_definition("g_" + lib_name, "copy_g", in_type, out_type)
+        << get_template_definition(
+               "gg1_" + lib_name, "copy_gg_nd1", in_type, out_type)
+        << get_template_definition(
+               "gg2_" + lib_name, "copy_gg_nd2", in_type, out_type)
+        << get_template_definition(
+               "gg3_" + lib_name, "copy_gg_nd3", in_type, out_type)
+        << get_template_definition(
+               "gg_" + lib_name, "copy_gg", in_type, out_type);
     lib = d.get_library(lib_name, kernel_source.str());
   }
   return d.get_kernel(kernel_name, lib);
@@ -296,11 +288,11 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
   if (lib == nullptr) {
     std::ostringstream kernel_source;
     kernel_source << metal::utils() << metal::sort();
-    std::vector<std::pair<std::string, std::string>> kernel_types = {
-        {"sort_", "mb_block_sort"},
-        {"partition_", "mb_block_partition"},
-        {"merge_", "mb_block_merge"}};
-    for (auto [name, func] : kernel_types) {
+    std::array<std::pair<std::string, std::string>, 3> kernel_types = {
+        {{"sort_", "mb_block_sort"},
+         {"partition_", "mb_block_partition"},
+         {"merge_", "mb_block_merge"}}};
+    for (auto& [name, func] : kernel_types) {
       kernel_source << get_template_definition(
           name + lib_name,
           func,
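
Aside: the refactor above leans entirely on get_template_definition instead of fmt-formatted source blobs. A standalone sketch of the kind of string such a helper produces; the helper below is hypothetical, and mlx's real signature and formatting may differ:

// Builds one explicit-instantiation line for a named Metal kernel template.
#include <iostream>
#include <sstream>
#include <string>

template <typename... Args>
std::string make_template_definition(
    const std::string& name,
    const std::string& func,
    Args... args) {
  std::ostringstream params;
  bool first = true;
  // Join the template parameters with ", ".
  ((params << (first ? (first = false, "") : ", ") << args), ...);
  return "template [[host_name(\"" + name + "\")]] [[kernel]] decltype(" +
      func + "<" + params.str() + ">) " + func + "<" + params.str() + ">;\n";
}

int main() {
  // E.g. a strided copy kernel instantiated for half -> float:
  std::cout << make_template_definition(
      "g_copyf16f32", "copy_g", "half", "float");
}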

View File

@@ -70,16 +70,16 @@ IndexValPair<U> simd_shuffle_down(IndexValPair<U> data, uint16_t delta) {
       simd_shuffle_down(data.index, delta), simd_shuffle_down(data.val, delta)};
 }

-template <typename T, typename Op, int N_READS>
+template <typename T, typename Op, int N_READS = 4>
 [[kernel]] void arg_reduce_general(
     const device T* in [[buffer(0)]],
     device uint32_t* out [[buffer(1)]],
-    const device int* shape [[buffer(2)]],
-    const device size_t* in_strides [[buffer(3)]],
-    const device size_t* out_strides [[buffer(4)]],
-    const device size_t& ndim [[buffer(5)]],
-    const device size_t& axis_stride [[buffer(6)]],
-    const device size_t& axis_size [[buffer(7)]],
+    const constant int* shape [[buffer(2)]],
+    const constant size_t* in_strides [[buffer(3)]],
+    const constant size_t* out_strides [[buffer(4)]],
+    const constant size_t& ndim [[buffer(5)]],
+    const constant size_t& axis_stride [[buffer(6)]],
+    const constant size_t& axis_size [[buffer(7)]],
     uint gid [[thread_position_in_grid]],
     uint lid [[thread_position_in_threadgroup]],
     uint lsize [[threads_per_threadgroup]],
@@ -159,28 +159,12 @@ template <typename T, typename Op, int N_READS>
   }
 }

-#define instantiate_arg_reduce_helper(name, itype, op)         \
-  template [[host_name(name)]] [[kernel]] void                 \
-  arg_reduce_general<itype, op<itype>, 4>(                     \
-      const device itype* in [[buffer(0)]],                    \
-      device uint32_t* out [[buffer(1)]],                      \
-      const device int* shape [[buffer(2)]],                   \
-      const device size_t* in_strides [[buffer(3)]],           \
-      const device size_t* out_strides [[buffer(4)]],          \
-      const device size_t& ndim [[buffer(5)]],                 \
-      const device size_t& axis_stride [[buffer(6)]],          \
-      const device size_t& axis_size [[buffer(7)]],            \
-      uint gid [[thread_position_in_grid]],                    \
-      uint lid [[thread_position_in_threadgroup]],             \
-      uint lsize [[threads_per_threadgroup]],                  \
-      uint simd_size [[threads_per_simdgroup]],                \
-      uint simd_lane_id [[thread_index_in_simdgroup]],         \
-      uint simd_group_id [[simdgroup_index_in_threadgroup]]);
-
 // clang-format off
-#define instantiate_arg_reduce(name, itype)                      \
-  instantiate_arg_reduce_helper("argmin_" #name , itype, ArgMin) \
-  instantiate_arg_reduce_helper("argmax_" #name , itype, ArgMax)
+#define instantiate_arg_reduce(name, itype) \
+  instantiate_kernel(                       \
+      "argmin_" #name, arg_reduce_general, itype, ArgMin<itype>) \
+  instantiate_kernel(                       \
+      "argmax_" #name, arg_reduce_general, itype, ArgMax<itype>)

 instantiate_arg_reduce(bool_, bool)
 instantiate_arg_reduce(uint8, uint8_t)
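
Aside: instantiate_kernel collapses each hand-written instantiation above into one line. A minimal sketch of how a host_name-based macro of this kind can be written; this is hypothetical, and mlx's actual macro may differ in detail:

// decltype(func<...>) names the specialization's type, so the explicit
// instantiation does not have to restate the full parameter list.
#define instantiate_kernel(name, func, ...)  \
  template [[host_name(                      \
      name)]] [[kernel]] decltype(func<__VA_ARGS__>) func<__VA_ARGS__>;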

View File

@@ -93,7 +93,7 @@ template <typename T, typename U, typename Op>
     uint2 grid_dim [[threads_per_grid]]) {
   auto a_idx = elem_to_loc_2(index, a_strides);
   auto b_idx = elem_to_loc_2(index, b_strides);
-  size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
+  size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
   c[out_idx] = Op()(a[a_idx], b[b_idx]);
 }

@@ -109,26 +109,10 @@ template <typename T, typename U, typename Op>
   auto a_idx = elem_to_loc_3(index, a_strides);
   auto b_idx = elem_to_loc_3(index, b_strides);
   size_t out_idx =
-      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
+      index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
   c[out_idx] = Op()(a[a_idx], b[b_idx]);
 }

-template <typename T, typename U, typename Op, int DIM>
-[[kernel]] void binary_g_nd(
-    device const T* a,
-    device const T* b,
-    device U* c,
-    constant const int shape[DIM],
-    constant const size_t a_strides[DIM],
-    constant const size_t b_strides[DIM],
-    uint3 index [[thread_position_in_grid]],
-    uint3 grid_dim [[threads_per_grid]]) {
-  auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
-  size_t out_idx =
-      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
-  c[out_idx] = Op()(a[idx.x], b[idx.y]);
-}
-
 template <typename T, typename U, typename Op>
 [[kernel]] void binary_g(
     device const T* a,
@@ -141,6 +125,7 @@ template <typename T, typename U, typename Op>
     uint3 index [[thread_position_in_grid]],
     uint3 grid_dim [[threads_per_grid]]) {
   auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
-  size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
+  size_t out_idx =
+      index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
   c[out_idx] = Op()(a[idx.x], b[idx.y]);
 }
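
Aside: the recurring `size_t(grid_dim.y)` change above is the core overflow fix: each factor fits in 32 bits, but the product need not. A host-side demonstration with made-up numbers:

// One 64-bit operand is enough to pull the whole expression into 64 bits.
#include <cstdint>
#include <cstdio>

int main() {
  uint32_t grid_y = 70000, y = 12345, z = 70000;
  uint32_t wrapped = y + grid_y * z;            // 4.9e9 wraps modulo 2^32
  uint64_t correct = y + uint64_t(grid_y) * z;  // promoted, 64-bit math
  std::printf("wrapped: %u\ncorrect: %llu\n", wrapped,
              (unsigned long long)correct);
}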

View File

@@ -21,8 +21,6 @@
   instantiate_kernel("g1" #op #tname, binary_g_nd1, itype, otype, op) \
   instantiate_kernel("g2" #op #tname, binary_g_nd2, itype, otype, op) \
   instantiate_kernel("g3" #op #tname, binary_g_nd3, itype, otype, op) \
-  instantiate_kernel("g4" #op #tname, binary_g_nd, itype, otype, op, 4) \
-  instantiate_kernel("g5" #op #tname, binary_g_nd, itype, otype, op, 5)

 #define instantiate_binary_integer(op) \
   instantiate_binary_all(op, uint8, uint8_t, uint8_t) \

View File

@@ -118,7 +118,7 @@ template <typename T, typename U, typename Op>
     uint2 grid_dim [[threads_per_grid]]) {
   auto a_idx = elem_to_loc_2(index, a_strides);
   auto b_idx = elem_to_loc_2(index, b_strides);
-  size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
+  size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
   auto out = Op()(a[a_idx], b[b_idx]);
   c[out_idx] = out[0];
   d[out_idx] = out[1];
@@ -137,31 +137,12 @@ template <typename T, typename U, typename Op>
   auto a_idx = elem_to_loc_3(index, a_strides);
   auto b_idx = elem_to_loc_3(index, b_strides);
   size_t out_idx =
-      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
+      index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
   auto out = Op()(a[a_idx], b[b_idx]);
   c[out_idx] = out[0];
   d[out_idx] = out[1];
 }

-template <typename T, typename U, typename Op, int DIM>
-[[kernel]] void binary_g_nd(
-    device const T* a,
-    device const T* b,
-    device U* c,
-    device U* d,
-    constant const int shape[DIM],
-    constant const size_t a_strides[DIM],
-    constant const size_t b_strides[DIM],
-    uint3 index [[thread_position_in_grid]],
-    uint3 grid_dim [[threads_per_grid]]) {
-  auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
-  size_t out_idx =
-      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
-  auto out = Op()(a[idx.x], b[idx.y]);
-  c[out_idx] = out[0];
-  d[out_idx] = out[1];
-}
-
 template <typename T, typename U, typename Op>
 [[kernel]] void binary_g(
     device const T* a,
@@ -175,7 +156,8 @@ template <typename T, typename U, typename Op>
     uint3 index [[thread_position_in_grid]],
     uint3 grid_dim [[threads_per_grid]]) {
   auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
-  size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
+  size_t out_idx =
+      index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
   auto out = Op()(a[idx.x], b[idx.y]);
   c[out_idx] = out[0];
   d[out_idx] = out[1];

View File

@@ -19,8 +19,6 @@
   instantiate_kernel("g1" #op #tname, binary_g_nd1, itype, otype, op) \
   instantiate_kernel("g2" #op #tname, binary_g_nd2, itype, otype, op) \
   instantiate_kernel("g3" #op #tname, binary_g_nd3, itype, otype, op) \
-  instantiate_kernel("g4" #op #tname, binary_g_nd, itype, otype, op, 4) \
-  instantiate_kernel("g5" #op #tname, binary_g_nd, itype, otype, op, 5)

 #define instantiate_binary_float(op) \
   instantiate_binary_all(op, float16, half, half) \

View File

@@ -71,20 +71,6 @@ template <typename T, typename U>
   dst[dst_idx] = static_cast<U>(src[src_idx]);
 }

-template <typename T, typename U, int DIM>
-[[kernel]] void copy_g_nd(
-    device const T* src [[buffer(0)]],
-    device U* dst [[buffer(1)]],
-    constant const int* src_shape [[buffer(2)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    uint3 index [[thread_position_in_grid]],
-    uint3 grid_dim [[threads_per_grid]]) {
-  auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
-  int64_t dst_idx =
-      index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
-  dst[dst_idx] = static_cast<U>(src[src_idx]);
-}
-
 template <typename T, typename U>
 [[kernel]] void copy_g(
     device const T* src [[buffer(0)]],
@@ -136,19 +122,6 @@ template <typename T, typename U>
   dst[dst_idx] = static_cast<U>(src[src_idx]);
 }

-template <typename T, typename U, int DIM>
-[[kernel]] void copy_gg_nd(
-    device const T* src [[buffer(0)]],
-    device U* dst [[buffer(1)]],
-    constant const int* src_shape [[buffer(2)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    constant const int64_t* dst_strides [[buffer(4)]],
-    uint3 index [[thread_position_in_grid]]) {
-  auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
-  auto dst_idx = elem_to_loc_nd<DIM>(index, src_shape, dst_strides);
-  dst[dst_idx] = static_cast<U>(src[src_idx]);
-}
-
 template <typename T, typename U>
 [[kernel]] void copy_gg(
     device const T* src [[buffer(0)]],

View File

@@ -16,10 +16,6 @@
   instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype) \
   instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype) \
   instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype) \
-  instantiate_kernel("g4_copy" #tname, copy_g_nd, itype, otype, 4) \
-  instantiate_kernel("g5_copy" #tname, copy_g_nd, itype, otype, 5) \
-  instantiate_kernel("gg4_copy" #tname, copy_gg_nd, itype, otype, 4) \
-  instantiate_kernel("gg5_copy" #tname, copy_gg_nd, itype, otype, 5) \
   instantiate_kernel("g_copy" #tname, copy_g, itype, otype) \
   instantiate_kernel("gg_copy" #tname, copy_gg, itype, otype)

View File

@@ -69,9 +69,9 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
     device char* out,
     device const bool& odd,
     device const uint& bytes_per_key,
-    device const int& ndim,
-    device const int* key_shape,
-    device const size_t* key_strides,
+    constant const int& ndim,
+    constant const int* key_shape,
+    constant const size_t* key_strides,
     uint2 grid_dim [[threads_per_grid]],
     uint2 index [[thread_position_in_grid]]) {
   auto kidx = 2 * index.x;

View File

@@ -342,9 +342,9 @@ template <
     const constant int& in_stride_sorted_axis [[buffer(3)]],
     const constant int& out_stride_sorted_axis [[buffer(4)]],
     const constant int& nc_dim [[buffer(5)]],
-    const device int* nc_shape [[buffer(6)]],
-    const device size_t* in_nc_strides [[buffer(7)]],
-    const device size_t* out_nc_strides [[buffer(8)]],
+    const constant int* nc_shape [[buffer(6)]],
+    const constant size_t* in_nc_strides [[buffer(7)]],
+    const constant size_t* out_nc_strides [[buffer(8)]],
     uint3 tid [[threadgroup_position_in_grid]],
     uint3 lid [[thread_position_in_threadgroup]]) {
   using sort_kernel =
@@ -485,8 +485,8 @@ template <
     const constant int& size_sorted_axis [[buffer(3)]],
     const constant int& stride_sorted_axis [[buffer(4)]],
     const constant int& nc_dim [[buffer(5)]],
-    const device int* nc_shape [[buffer(6)]],
-    const device size_t* nc_strides [[buffer(7)]],
+    const constant int* nc_shape [[buffer(6)]],
+    const constant size_t* nc_strides [[buffer(7)]],
     uint3 tid [[threadgroup_position_in_grid]],
     uint3 lid [[thread_position_in_threadgroup]]) {
   using sort_kernel = KernelMultiBlockMergeSort<

View File

@@ -52,7 +52,7 @@ template <typename T, typename Op>
   auto a_idx = elem_to_loc_2(index, a_strides);
   auto b_idx = elem_to_loc_2(index, b_strides);
   auto c_idx = elem_to_loc_2(index, c_strides);
-  size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
+  size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
   d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
 }

@@ -71,29 +71,10 @@ template <typename T, typename Op>
   auto b_idx = elem_to_loc_3(index, b_strides);
   auto c_idx = elem_to_loc_3(index, c_strides);
   size_t out_idx =
-      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
+      index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
   d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
 }

-template <typename T, typename Op, int DIM>
-[[kernel]] void ternary_g_nd(
-    device const bool* a,
-    device const T* b,
-    device const T* c,
-    device T* d,
-    constant const int shape[DIM],
-    constant const size_t a_strides[DIM],
-    constant const size_t b_strides[DIM],
-    constant const size_t c_strides[DIM],
-    uint3 index [[thread_position_in_grid]],
-    uint3 grid_dim [[threads_per_grid]]) {
-  auto idx =
-      elem_to_loc_3_nd<DIM>(index, shape, a_strides, b_strides, c_strides);
-  size_t out_idx =
-      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
-  d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
-}
-
 template <typename T, typename Op>
 [[kernel]] void ternary_g(
     device const bool* a,
@@ -109,6 +90,7 @@ template <typename T, typename Op>
     uint3 grid_dim [[threads_per_grid]]) {
   auto idx =
       elem_to_loc_3_nd(index, shape, a_strides, b_strides, c_strides, ndim);
-  size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
+  size_t out_idx =
+      index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
   d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
 }

View File

@@ -16,8 +16,6 @@
   instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op) \
   instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op) \
   instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op) \
-  instantiate_kernel("g4_" #op #tname, ternary_g_nd, type, op, 4) \
-  instantiate_kernel("g5_" #op #tname, ternary_g_nd, type, op, 5)

 #define instantiate_ternary_types(op) \
   instantiate_ternary_all(op, bool_, bool) \

View File

@@ -22,8 +22,8 @@ template <typename T, typename Op>
 [[kernel]] void unary_g(
     device const T* in,
     device T* out,
-    device const int* in_shape,
-    device const size_t* in_strides,
+    constant const int* in_shape,
+    constant const size_t* in_strides,
     device const int& ndim,
     uint index [[thread_position_in_grid]]) {
   auto idx = elem_to_loc(index, in_shape, in_strides, ndim);

View File

@@ -83,20 +83,6 @@ struct Limits<complex64_t> {
 ///////////////////////////////////////////////////////////////////////////////
 // Single Array with generic dims

-template <typename stride_t>
-METAL_FUNC stride_t elem_to_loc(
-    uint elem,
-    device const int* shape,
-    device const stride_t* strides,
-    int ndim) {
-  stride_t loc = 0;
-  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
-    loc += (elem % shape[i]) * strides[i];
-    elem /= shape[i];
-  }
-  return loc;
-}
-
 template <typename stride_t>
 METAL_FUNC stride_t elem_to_loc(
     uint elem,
@@ -111,20 +97,6 @@ METAL_FUNC stride_t elem_to_loc(
   return loc;
 }

-template <typename stride_t>
-METAL_FUNC stride_t elem_to_loc(
-    stride_t elem,
-    device const int* shape,
-    device const stride_t* strides,
-    int ndim) {
-  stride_t loc = 0;
-  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
-    loc += (elem % shape[i]) * strides[i];
-    elem /= shape[i];
-  }
-  return loc;
-}
-
 template <typename stride_t>
 METAL_FUNC stride_t elem_to_loc(
     stride_t elem,
@@ -174,78 +146,18 @@ elem_to_loc_3(uint3 elem, constant const stride_t strides[3]) {
   return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
 }

-template <int NDIM>
-METAL_FUNC size_t elem_to_loc_nd(
-    uint elem,
-    device const int* shape,
-    device const size_t* strides) {
-  size_t loc = (elem % shape[NDIM - 1]) * strides[NDIM - 1];
-  MLX_MTL_PRAGMA_UNROLL
-  for (int d = NDIM - 2; d >= 0; --d) {
-    elem /= shape[d + 1];
-    loc += (elem % shape[d]) * strides[d];
-  }
-  return loc;
-}
-
-template <int NDIM>
-METAL_FUNC size_t elem_to_loc_nd(
-    uint3 elem,
-    constant const int shape[NDIM],
-    constant const size_t strides[NDIM]) {
-  size_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
-  for (int d = NDIM - 3; d >= 0; --d) {
-    loc += (elem.z % shape[d]) * strides[d];
-    elem.z /= shape[d];
-  }
-  return loc;
-}
-
-template <int NDIM>
-METAL_FUNC int64_t elem_to_loc_nd(
-    uint elem,
-    constant const int shape[NDIM],
-    constant const int64_t strides[NDIM]) {
-  int64_t loc = (elem % shape[NDIM - 1]) * strides[NDIM - 1];
-  MLX_MTL_PRAGMA_UNROLL
-  for (int d = NDIM - 2; d >= 0; --d) {
-    elem /= shape[d + 1];
-    loc += (elem % shape[d]) * strides[d];
-  }
-  return loc;
-}
-
-template <int NDIM>
-METAL_FUNC int64_t elem_to_loc_nd(
-    uint3 elem,
-    constant const int shape[NDIM],
-    constant const int64_t strides[NDIM]) {
-  int64_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
-  for (int d = NDIM - 3; d >= 0; --d) {
-    loc += (elem.z % shape[d]) * strides[d];
-    elem.z /= shape[d];
-  }
-  return loc;
-}
-
 ///////////////////////////////////////////////////////////////////////////////
 // Multiple Arrays with generic dims

-METAL_FUNC uint2 elem_to_loc_2_nd(
+METAL_FUNC ulong2 elem_to_loc_2_nd(
     uint3 elem,
     constant const int* shape,
     constant const size_t* a_strides,
     constant const size_t* b_strides,
     int ndim) {
-  uint2 loc = {
-      static_cast<uint>(
-          elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
-      static_cast<uint>(
-          elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
+  ulong2 loc = {
+      elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2],
+      elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2]};
   for (int d = ndim - 3; d >= 0; --d) {
     uint l = elem.z % shape[d];
     loc.x += l * a_strides[d];
@@ -255,20 +167,17 @@ METAL_FUNC uint2 elem_to_loc_2_nd(
   return loc;
 }

-METAL_FUNC uint3 elem_to_loc_3_nd(
+METAL_FUNC ulong3 elem_to_loc_3_nd(
     uint3 elem,
     constant const int* shape,
     constant const size_t* a_strides,
     constant const size_t* b_strides,
     constant const size_t* c_strides,
     int ndim) {
-  uint3 loc = {
-      static_cast<uint>(
-          elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
-      static_cast<uint>(
-          elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2]),
-      static_cast<uint>(
-          elem.x * c_strides[ndim - 1] + elem.y * c_strides[ndim - 2])};
+  ulong3 loc = {
+      elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2],
+      elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2],
+      elem.x * c_strides[ndim - 1] + elem.y * c_strides[ndim - 2]};
   for (int d = ndim - 3; d >= 0; --d) {
     uint l = elem.z % shape[d];
     loc.x += l * a_strides[d];
@@ -279,53 +188,6 @@ METAL_FUNC ulong3 elem_to_loc_3_nd(
   return loc;
 }

-///////////////////////////////////////////////////////////////////////////////
-// Multiple Arrays with fixed N dims
-
-template <int NDIM>
-METAL_FUNC uint2 elem_to_loc_2_nd(
-    uint3 elem,
-    constant const int shape[NDIM],
-    constant const size_t a_strides[NDIM],
-    constant const size_t b_strides[NDIM]) {
-  uint2 loc = {
-      static_cast<uint>(
-          elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
-      static_cast<uint>(
-          elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2])};
-  for (int d = NDIM - 3; d >= 0; --d) {
-    uint l = elem.z % shape[d];
-    loc.x += l * a_strides[d];
-    loc.y += l * b_strides[d];
-    elem.z /= shape[d];
-  }
-  return loc;
-}
-
-template <int NDIM>
-METAL_FUNC uint3 elem_to_loc_3_nd(
-    uint3 elem,
-    constant const int shape[NDIM],
-    constant const size_t a_strides[NDIM],
-    constant const size_t b_strides[NDIM],
-    constant const size_t c_strides[NDIM]) {
-  uint3 loc = {
-      static_cast<uint>(
-          elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
-      static_cast<uint>(
-          elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2]),
-      static_cast<uint>(
-          elem.x * c_strides[NDIM - 1] + elem.y * c_strides[NDIM - 2])};
-  for (int d = NDIM - 3; d >= 0; --d) {
-    uint l = elem.z % shape[d];
-    loc.x += l * a_strides[d];
-    loc.y += l * b_strides[d];
-    loc.z += l * c_strides[d];
-    elem.z /= shape[d];
-  }
-  return loc;
-}
-
 ///////////////////////////////////////////////////////////////////////////////
 // Elem to loc in a loop utils
 ///////////////////////////////////////////////////////////////////////////////
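
Aside: the surviving generic elem_to_loc is the piece that makes these kernels stride-agnostic. A host-side rendering of the same arithmetic; the shape and strides below are illustrative:

// Map a linear element id to a memory offset through arbitrary strides,
// walking axes from fastest-moving to slowest.
#include <cstddef>
#include <cstdio>
#include <vector>

size_t elem_to_loc(
    size_t elem,
    const std::vector<int>& shape,
    const std::vector<size_t>& strides) {
  size_t loc = 0;
  for (int i = (int)shape.size() - 1; i >= 0 && elem > 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}

int main() {
  // A transposed view of a (3, 4) row-contiguous array: shape (4, 3),
  // strides (1, 4).
  std::vector<int> shape = {4, 3};
  std::vector<size_t> strides = {1, 4};
  for (size_t e = 0; e < 12; ++e) {
    std::printf("element %zu -> offset %zu\n", e,
                elem_to_loc(e, shape, strides));
  }
}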

View File

@@ -8,7 +8,7 @@
 namespace mlx::core {

-constexpr int MAX_TERNARY_SPECIALIZED_DIMS = 5;
+constexpr int MAX_TERNARY_SPECIALIZED_DIMS = 3;

 void ternary_op_gpu_inplace(
     const std::vector<array>& inputs,
@@ -26,11 +26,21 @@ void ternary_op_gpu_inplace(
   }

   // Try to collapse contiguous dims
-  auto [shape, strides] = collapse_contiguous_dims(a, b, c, out);
-  auto& strides_a = strides[0];
-  auto& strides_b = strides[1];
-  auto& strides_c = strides[2];
-  auto& strides_out = strides[3];
+  auto maybe_collapse = [topt, &a, &b, &c, &out]() {
+    if (topt == TernaryOpType::General) {
+      // The size cap here should ideally be `UINT32_MAX` but we are
+      // limited by the shape being an int.
+      auto [shape, strides] = collapse_contiguous_dims(
+          {a, b, c, out},
+          /* size_cap = */ INT32_MAX);
+      return std::make_tuple(
+          shape, strides[0], strides[1], strides[2], strides[3]);
+    } else {
+      std::vector<size_t> e;
+      return std::make_tuple(std::vector<int>{}, e, e, e, e);
+    }
+  };
+  auto [shape, strides_a, strides_b, strides_c, strides_out] = maybe_collapse();

   bool use_2d = out.data_size() > UINT_MAX;
   std::string kernel_name;
@@ -88,7 +98,7 @@ void ternary_op_gpu_inplace(
     size_t rest = out.size() / (dim0 * dim1);
     NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
     if (thread_group_size != 1024) {
-      throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
+      throw std::runtime_error("[Metal::ternary] Must use 1024 sized block");
     }
     MTL::Size group_dims = get_block_dims(dim0, dim1, rest);
     MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);

View File

@@ -6,4 +6,5 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/../common/compiled.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/../common/compiled_nocpu.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/../common/reduce_utils.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/../common/slicing.cpp)
+  ${CMAKE_CURRENT_SOURCE_DIR}/../common/slicing.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/../common/utils.cpp)

View File

@@ -243,7 +243,7 @@ void init_fast(nb::module_& parent_module) {
             template_args.emplace_back(name, dtype);
           } else {
             throw std::invalid_argument(
-                "[[metal_kernel]] Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`.");
+                "[metal_kernel] Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`.");
           }
         }
       }
@@ -271,25 +271,24 @@ void init_fast(nb::module_& parent_module) {
           nb::sig(
               "def __call__(self, *, inputs: List[Union[scalar, array]], output_shapes: List[Sequence[int]], output_dtypes: List[Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[List[Tuple[str, Union[bool, int, Dtype]]]] = None, init_value: Optional[float] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)"),
           R"pbdoc(
            Run the kernel.

            Args:
              inputs (List[array]): The inputs passed to the Metal kernel.
              output_shapes (List[Sequence[int]]): The list of shapes for each output in ``output_names``.
              output_dtypes (List[Dtype]): The list of data types for each output in ``output_names``.
              grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
              threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
              template (List[Tuple[str, Union[bool, int, Dtype]]], optional): Template arguments.
                These will be added as template arguments to the kernel definition. Default: ``None``.
              init_value (float, optional): Optional value to use to initialize all of the output arrays.
                By default, output arrays are uninitialized. Default: ``None``.
              verbose (bool, optional): Whether to print the full generated source code of the kernel
                when it is run. Default: ``False``.
              stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.

            Returns:
-              List[array]: The list of output arrays.
-          )pbdoc");
+              List[array]: The list of output arrays.)pbdoc");
         },
         "name"_a,
         "input_names"_a,
@@ -306,16 +305,16 @@ void init_fast(nb::module_& parent_module) {
            input_names (List[str]): The parameter names of the inputs in the
              function signature.
            output_names (List[str]): The parameter names of the outputs in the
              function signature.
            source (str): Source code. This is the body of a function in Metal,
              the function signature will be automatically generated.
            header (str): Header source code to include before the main function.
              Useful for helper functions or includes that should live outside of
              the main function body.
            ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
              before the kernel runs. Default: ``True``.
            atomic_outputs (bool): Whether to use atomic outputs in the function signature
              e.g. ``device atomic<float>``. Default: ``False``.

            Returns:
              Callable ``metal_kernel``.