From 4f46e9c99793b37bac1905047893892e643c59b5 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Tue, 17 Sep 2024 12:46:31 -0700
Subject: [PATCH] More fixes for arrays with large sizes (#1405)

* compile works for big arrays when contiguous

* style

* nits in docs

* a bunch more stuff

* update jit

* update jit

* use constant for shapes and strides and remove elem_to_loc overload

* use kernel instantiation

* docs nits

* update binary and ternary

* comments
---
 docs/src/python/nn/functions.rst           |   1 +
 docs/src/python/nn/layers.rst              |   5 +
 mlx/backend/common/CMakeLists.txt          |   1 +
 mlx/backend/common/utils.cpp               |  88 ++++++++++
 mlx/backend/common/utils.h                 |  56 ++------
 mlx/backend/metal/binary.cpp               | 132 ++++++-----------
 mlx/backend/metal/compiled.cpp             |  45 +++++-
 mlx/backend/metal/copy.cpp                 |  20 ++-
 mlx/backend/metal/jit/copy.h               | 100 -------------
 mlx/backend/metal/jit_kernels.cpp          |  88 ++++++------
 mlx/backend/metal/kernels/arg_reduce.metal |  38 ++---
 mlx/backend/metal/kernels/binary.h         |  23 +--
 mlx/backend/metal/kernels/binary.metal     |   2 -
 mlx/backend/metal/kernels/binary_two.h     |  26 +---
 mlx/backend/metal/kernels/binary_two.metal |   2 -
 mlx/backend/metal/kernels/copy.h           |  27 ----
 mlx/backend/metal/kernels/copy.metal       |   4 -
 mlx/backend/metal/kernels/random.metal     |   6 +-
 mlx/backend/metal/kernels/sort.h           |  10 +-
 mlx/backend/metal/kernels/ternary.h        |  26 +---
 mlx/backend/metal/kernels/ternary.metal    |   2 -
 mlx/backend/metal/kernels/unary.h          |   4 +-
 mlx/backend/metal/kernels/utils.h          | 156 ++------------------
 mlx/backend/metal/ternary.cpp              |  24 +++-
 mlx/backend/no_cpu/CMakeLists.txt          |   3 +-
 python/src/fast.cpp                        |  47 +++----
 26 files changed, 325 insertions(+), 611 deletions(-)
 create mode 100644 mlx/backend/common/utils.cpp
 delete mode 100644 mlx/backend/metal/jit/copy.h

diff --git a/docs/src/python/nn/functions.rst b/docs/src/python/nn/functions.rst
index f1077776b..9b6cd9f62 100644
--- a/docs/src/python/nn/functions.rst
+++ b/docs/src/python/nn/functions.rst
@@ -13,6 +13,7 @@ simple functions.
    :template: nn-module-template.rst
 
    elu
+   celu
    gelu
    gelu_approx
    gelu_fast_approx
diff --git a/docs/src/python/nn/layers.rst b/docs/src/python/nn/layers.rst
index 77105ea35..fc24d410b 100644
--- a/docs/src/python/nn/layers.rst
+++ b/docs/src/python/nn/layers.rst
@@ -13,6 +13,7 @@ Layers
    AvgPool1d
    AvgPool2d
    BatchNorm
+   CELU
    Conv1d
    Conv2d
    Conv3d
@@ -23,6 +24,7 @@ Layers
    Dropout2d
    Dropout3d
    Embedding
+   ELU
    GELU
    GLU
    GroupNorm
@@ -34,6 +36,8 @@ Layers
    LayerNorm
    LeakyReLU
    Linear
+   LogSigmoid
+   LogSoftmax
    LSTM
    MaxPool1d
    MaxPool2d
@@ -49,6 +53,7 @@ Layers
    RoPE
    SELU
    Sequential
+   Sigmoid
    SiLU
    SinusoidalPositionalEncoding
    Softmin
diff --git a/mlx/backend/common/CMakeLists.txt b/mlx/backend/common/CMakeLists.txt
index 56343ada4..925f4731c 100644
--- a/mlx/backend/common/CMakeLists.txt
+++ b/mlx/backend/common/CMakeLists.txt
@@ -51,6 +51,7 @@ target_sources(
     ${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
     ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp)
 
 if(IOS)
diff --git a/mlx/backend/common/utils.cpp b/mlx/backend/common/utils.cpp
new file mode 100644
index 000000000..30e743a79
--- /dev/null
+++ b/mlx/backend/common/utils.cpp
@@ -0,0 +1,88 @@
+// Copyright © 2023-2024 Apple Inc.
+
+#include "mlx/backend/common/utils.h"
+
+namespace mlx::core {
+
+template <typename stride_t>
+std::tuple<std::vector<int>, std::vector<std::vector<stride_t>>>
+collapse_contiguous_dims_impl(
+    const std::vector<int>& shape,
+    const std::vector<std::vector<stride_t>>& strides,
+    stride_t size_cap) {
+  // Make a vector that has axes separated with -1. Collapse all axes between
+  // -1.
+  std::vector<int> to_collapse;
+  if (shape.size() > 0) {
+    if (shape[0] != 1) {
+      to_collapse.push_back(0);
+    }
+    size_t size = shape[0];
+    for (int i = 1; i < shape.size(); i++) {
+      bool contiguous = true;
+      size *= shape[i];
+      for (const std::vector<stride_t>& st : strides) {
+        if (st[i] * shape[i] != st[i - 1] || size > size_cap) {
+          contiguous = false;
+          size = shape[i];
+          break;
+        }
+      }
+      if (!contiguous) {
+        to_collapse.push_back(-1);
+      }
+      if (shape[i] != 1) {
+        to_collapse.push_back(i);
+      }
+    }
+    to_collapse.push_back(-1);
+  }
+
+  std::vector<int> out_shape;
+  std::vector<std::vector<stride_t>> out_strides(strides.size());
+  for (int i = 0;;) {
+    while (i < to_collapse.size() && to_collapse[i] == -1) {
+      ++i;
+    };
+    if (i == to_collapse.size()) {
+      break;
+    }
+    int current_shape = shape[to_collapse[i]];
+    int k = i;
+    while (to_collapse[++k] != -1) {
+      current_shape *= shape[to_collapse[k]];
+    }
+    out_shape.push_back(current_shape);
+    for (int j = 0; j < strides.size(); j++) {
+      const std::vector<stride_t>& st = strides[j];
+      out_strides[j].push_back(st[to_collapse[k - 1]]);
+    }
+    i = k + 1;
+  }
+
+  if (!shape.empty() && out_shape.empty()) {
+    out_shape.push_back(1);
+    for (auto& out_stride : out_strides) {
+      out_stride.push_back(0);
+    }
+  }
+  return std::make_tuple(out_shape, out_strides);
+}
+
+std::tuple<std::vector<int>, std::vector<std::vector<int64_t>>>
+collapse_contiguous_dims(
+    const std::vector<int>& shape,
+    const std::vector<std::vector<int64_t>>& strides,
+    int64_t size_cap /* = std::numeric_limits<int32_t>::max() */) {
+  return collapse_contiguous_dims_impl(shape, strides, size_cap);
+}
+
+std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
+collapse_contiguous_dims(
+    const std::vector<int>& shape,
+    const std::vector<std::vector<size_t>>& strides,
+    size_t size_cap /* = std::numeric_limits<int32_t>::max() */) {
+  return collapse_contiguous_dims_impl(shape, strides, size_cap);
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/common/utils.h b/mlx/backend/common/utils.h
index 3c4a3d0a3..6f57fe11b 100644
--- a/mlx/backend/common/utils.h
+++ b/mlx/backend/common/utils.h
@@ -44,58 +44,26 @@ std::vector<stride_t> make_contiguous_strides(const std::vector<int>& shape) {
 //
 // When multiple arrays are passed they should all have the same shape. The
 // collapsed axes are also the same so one shape is returned.
-template <typename stride_t>
-inline std::tuple<std::vector<int>, std::vector<std::vector<stride_t>>>
+std::tuple<std::vector<int>, std::vector<std::vector<int64_t>>>
 collapse_contiguous_dims(
     const std::vector<int>& shape,
-    const std::vector<std::vector<stride_t>> strides) {
-  // Make a vector that has axes separated with -1. Collapse all axes between
-  // -1.
-  std::vector<int> to_collapse;
-  if (shape.size() > 0) {
-    to_collapse.push_back(0);
-    for (int i = 1; i < shape.size(); i++) {
-      bool contiguous = true;
-      for (const std::vector<stride_t>& st : strides) {
-        if (st[i] * shape[i] != st[i - 1]) {
-          contiguous = false;
-        }
-        if (!contiguous) {
-          break;
-        }
-      }
-      if (!contiguous) {
-        to_collapse.push_back(-1);
-      }
-      to_collapse.push_back(i);
-    }
-    to_collapse.push_back(-1);
-  }
-
-  std::vector<int> out_shape;
-  std::vector<std::vector<stride_t>> out_strides(strides.size());
-  for (int i = 0; i < to_collapse.size(); i++) {
-    int current_shape = shape[to_collapse[i]];
-    while (to_collapse[++i] != -1) {
-      current_shape *= shape[to_collapse[i]];
-    }
-    out_shape.push_back(current_shape);
-    for (int j = 0; j < strides.size(); j++) {
-      const std::vector<stride_t>& st = strides[j];
-      out_strides[j].push_back(st[to_collapse[i - 1]]);
-    }
-  }
-
-  return std::make_tuple(out_shape, out_strides);
-}
+    const std::vector<std::vector<int64_t>>& strides,
+    int64_t size_cap = std::numeric_limits<int32_t>::max());
+
+std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
+collapse_contiguous_dims(
+    const std::vector<int>& shape,
+    const std::vector<std::vector<size_t>>& strides,
+    size_t size_cap = std::numeric_limits<int32_t>::max());
 
 inline std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
-collapse_contiguous_dims(const std::vector<array>& xs) {
+collapse_contiguous_dims(
+    const std::vector<array>& xs,
+    size_t size_cap = std::numeric_limits<int32_t>::max()) {
   std::vector<std::vector<size_t>> strides;
   for (auto& x : xs) {
     strides.emplace_back(x.strides());
   }
-  return collapse_contiguous_dims(xs[0].shape(), strides);
+  return collapse_contiguous_dims(xs[0].shape(), strides, size_cap);
 }
 
 template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
diff --git a/mlx/backend/metal/binary.cpp b/mlx/backend/metal/binary.cpp
index 1c0dac55e..59a661fc8 100644
--- a/mlx/backend/metal/binary.cpp
+++ b/mlx/backend/metal/binary.cpp
@@ -19,7 +19,7 @@
 
 namespace mlx::core {
 
-constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;
+constexpr int MAX_BINARY_SPECIALIZED_DIMS = 3;
 
 std::string get_kernel_name(
     BinaryOpType bopt,
@@ -69,46 +69,61 @@ void binary_op_gpu_inplace(
   }
 
   // Try to collapse contiguous dims
-  auto [shape, strides] = collapse_contiguous_dims(a, b, out);
-  auto& strides_a = strides[0];
-  auto& strides_b = strides[1];
-  auto& strides_out = strides[2];
+  auto maybe_collapse = [bopt, &a, &b, &out]() {
+    if (bopt == BinaryOpType::General) {
+      // The size cap here should ideally be `UINT32_MAX` but we are
+      // limited by the shape being an int.
+      auto [shape, strides] = collapse_contiguous_dims(
+          {a, b, out},
+          /* size_cap = */ INT32_MAX);
+      return std::make_tuple(shape, strides[0], strides[1], strides[2]);
+    } else {
+      std::vector<size_t> e;
+      return std::make_tuple(std::vector<int>{}, e, e, e);
+    }
+  };
+  auto [shape, strides_a, strides_b, strides_out] = maybe_collapse();
 
   bool use_2d = out.data_size() > UINT32_MAX;
   std::string kernel_name = get_kernel_name(bopt, op, a, use_2d, shape.size());
   auto& d = metal::device(s.device);
 
-  auto kernel =
-      get_binary_two_kernel(d, kernel_name, a.dtype(), outputs[0].dtype(), op);
-
+  auto kernel = outputs.size() == 2
+      ? get_binary_two_kernel(d, kernel_name, a.dtype(), out.dtype(), op)
+      : get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op);
   auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder->setComputePipelineState(kernel);
 
   // - If a is donated it goes to the first output
   // - If b is donated it goes to the first output if a was not donated
-  //   otherwise it goes to the second output
+  //   otherwise it goes to the second output.
+  // - If there is only one output only one of a and b will be donated.
   bool donate_a = a.data_shared_ptr() == nullptr;
   bool donate_b = b.data_shared_ptr() == nullptr;
-  compute_encoder.set_input_array(donate_a ? outputs[0] : a, 0);
+  int arg_idx = 0;
+  compute_encoder.set_input_array(donate_a ? outputs[0] : a, arg_idx++);
   compute_encoder.set_input_array(
-      donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, 1);
-  compute_encoder.set_output_array(outputs[0], 2);
-  compute_encoder.set_output_array(outputs[1], 3);
+      donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, arg_idx++);
+  compute_encoder.set_output_array(outputs[0], arg_idx++);
+  if (outputs.size() == 2) {
+    compute_encoder.set_output_array(outputs[1], arg_idx++);
+  }
 
   if (bopt == BinaryOpType::General) {
     auto ndim = shape.size();
     if (ndim > 3) {
-      compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
-      compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5);
-      compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6);
+      compute_encoder->setBytes(shape.data(), ndim * sizeof(int), arg_idx++);
+      compute_encoder->setBytes(
+          strides_a.data(), ndim * sizeof(size_t), arg_idx++);
+      compute_encoder->setBytes(
+          strides_b.data(), ndim * sizeof(size_t), arg_idx++);
+      compute_encoder->setBytes(&ndim, sizeof(int), arg_idx++);
     } else {
       // The shape is implicit in the grid for <= 3D
-      compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
-      compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
-    }
-
-    if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
-      compute_encoder->setBytes(&ndim, sizeof(int), 7);
+      compute_encoder->setBytes(
+          strides_a.data(), ndim * sizeof(size_t), arg_idx++);
+      compute_encoder->setBytes(
+          strides_b.data(), ndim * sizeof(size_t), arg_idx++);
     }
 
     // Launch up to 3D grid of threads
@@ -125,9 +140,8 @@
   } else {
     // Launch a 1D or 2D grid of threads
     size_t nthreads = out.data_size();
-    MTL::Size grid_dims = use_2d
-        ? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
-        : MTL::Size(nthreads, 1, 1);
+    MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
+                                 : MTL::Size(nthreads, 1, 1);
     NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
     if (thread_group_size > nthreads) {
       thread_group_size = nthreads;
     }
@@ -164,72 +178,8 @@ void binary_op_gpu_inplace(
     array& out,
     const std::string& op,
     const Stream& s) {
-  auto& a = inputs[0];
-  auto& b = inputs[1];
-  auto bopt = get_binary_op_type(a, b);
-  if (out.size() == 0) {
-    return;
-  }
-
-  // Try to collapse contiguous dims
-  auto [shape, strides] = collapse_contiguous_dims(a, b, out);
-  auto& strides_a = strides[0];
-  auto& strides_b = strides[1];
-  auto& strides_out = strides[2];
-
-  bool use_2d = out.data_size() > UINT32_MAX;
-  std::string kernel_name = get_kernel_name(bopt, op, a, use_2d, shape.size());
-  auto& d = metal::device(s.device);
-
-  auto kernel = get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op);
-  auto& compute_encoder = d.get_command_encoder(s.index);
-  compute_encoder->setComputePipelineState(kernel);
-  bool donate_a = a.data_shared_ptr() == nullptr;
-  bool donate_b = b.data_shared_ptr() == nullptr;
-  compute_encoder.set_input_array(donate_a ? out : a, 0);
-  compute_encoder.set_input_array(donate_b ? out : b, 1);
-  compute_encoder.set_output_array(out, 2);
-
-  if (bopt == BinaryOpType::General) {
-    auto ndim = shape.size();
-    if (ndim > 3) {
-      compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 3);
-      compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
-      compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
-    } else {
-      // The shape is implicit in the grid for <= 3D
-      compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 3);
-      compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 4);
-    }
-
-    if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
-      compute_encoder->setBytes(&ndim, sizeof(int), 6);
-    }
-
-    // Launch up to 3D grid of threads
-    size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
-    size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
-    size_t rest = out.size() / (dim0 * dim1);
-    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
-    if (thread_group_size != 1024) {
-      throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
-    }
-    auto group_dims = get_block_dims(dim0, dim1, rest);
-    MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
-    compute_encoder.dispatchThreads(grid_dims, group_dims);
-  } else {
-    // Launch a 1D or 2D grid of threads
-
-    size_t nthreads = out.data_size();
-    MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
-                                 : MTL::Size(nthreads, 1, 1);
-    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
-    if (thread_group_size > nthreads) {
-      thread_group_size = nthreads;
-    }
-    MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-    compute_encoder.dispatchThreads(grid_dims, group_dims);
-  }
+  std::vector<array> outputs = {out};
+  binary_op_gpu_inplace(inputs, outputs, op, s);
 }
 
 void binary_op_gpu(
diff --git a/mlx/backend/metal/compiled.cpp b/mlx/backend/metal/compiled.cpp
index b6bb055f2..e90658650 100644
--- a/mlx/backend/metal/compiled.cpp
+++ b/mlx/backend/metal/compiled.cpp
@@ -22,7 +22,8 @@ inline void build_kernel(
     const std::unordered_set<uintptr_t>& constant_ids,
     bool contiguous,
     int ndim,
-    bool dynamic_dims) {
+    bool dynamic_dims,
+    bool use_big_index = false) {
   // All outputs should have the exact same shape and will be row contiguous
   auto output_shape = outputs[0].shape();
   auto output_strides = outputs[0].strides();
@@ -84,9 +85,15 @@
   // The thread index in the whole grid
   os << "  uint3 pos [[thread_position_in_grid]]," << std::endl
-     << "  uint3 grid [[threads_per_grid]]) {" << std::endl
-     << "  uint index = pos.x + grid.x * (pos.y + grid.y * pos.z);"
-     << std::endl;
+     << "  uint3 grid [[threads_per_grid]]) {" << std::endl;
+  if (use_big_index) {
+    // This is only used for contiguous kernels which don't have
+    // a third grid dimension
+    os << "  size_t index = pos.x + grid.x * size_t(pos.y);";
+  } else {
+    os << "  uint index = pos.x + grid.x * (pos.y + grid.y * pos.z);";
+  }
+  os << std::endl;
 
   // Extract the indices per axis to individual uints if we have arrays that
   // are broadcasted or transposed
@@ -212,6 +219,17 @@
       /* contiguous = */ true,
       /* ndim = */ 0,
       /* dynamic_dims = */ false);
+  build_kernel(
+      kernel,
+      kernel_lib_ + "_contiguous_big",
+      inputs_,
+      outputs_,
+      tape_,
+      constant_ids_,
+      /* contiguous = */ true,
+      /* ndim = */ 0,
+      /* dynamic_dims = */ false,
+      /* use_big_index = */ true);
   for (int i = 1; i < 8; i++) {
     build_kernel(
         kernel,
@@ -285,7 +303,16 @@
       initial_strides.push_back(std::move(xstrides));
     }
     std::tie(shape, strides) =
-        collapse_contiguous_dims(output_shape, initial_strides);
+        collapse_contiguous_dims(output_shape, initial_strides, INT32_MAX);
+  }
+
+  bool use_2d = false;
+  if (contiguous) {
+    size_t max_size = 0;
+    for (auto& in : inputs) {
+      max_size = std::max(max_size, in.data_size());
+    }
+    use_2d = (max_size > UINT32_MAX);
   }
 
   // Get the kernel from the lib
@@ -298,6 +325,8 @@
     } else {
       kernel_name += std::to_string(shape.size());
     }
+  } else if (use_2d) {
+    kernel_name += "_big";
   }
   auto kernel = d.get_kernel(kernel_name, lib);
   auto& compute_encoder = d.get_command_encoder(s.index);
@@ -348,8 +377,10 @@
   // Launch the kernel
   if (contiguous) {
-    size_t nthreads = outputs[0].size();
-    MTL::Size grid_dims(nthreads, 1, 1);
+    size_t nthreads = outputs[0].data_size();
+    MTL::Size grid_dims = use_2d
+        ? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
+        : MTL::Size(nthreads, 1, 1);
     MTL::Size group_dims(
         std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
     compute_encoder.dispatchThreads(grid_dims, group_dims);
diff --git a/mlx/backend/metal/copy.cpp b/mlx/backend/metal/copy.cpp
index 0fc99220d..19596f6b4 100644
--- a/mlx/backend/metal/copy.cpp
+++ b/mlx/backend/metal/copy.cpp
@@ -10,7 +10,7 @@
 
 namespace mlx::core {
 
-constexpr int MAX_COPY_SPECIALIZED_DIMS = 5;
+constexpr int MAX_COPY_SPECIALIZED_DIMS = 3;
 
 void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
   if (ctype == CopyType::Vector) {
@@ -59,10 +59,20 @@ void copy_gpu_inplace(
   }
 
   // Try to collapse contiguous dims
-  auto [shape, strides] = collapse_contiguous_dims(
-      data_shape, std::vector{strides_in_pre, strides_out_pre});
-  auto& strides_in_ = strides[0];
-  auto& strides_out_ = strides[1];
+  auto maybe_collapse =
+      [ctype, &data_shape, &strides_in_pre, &strides_out_pre]() {
+        if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
+          auto [shape, strides] = collapse_contiguous_dims(
+              data_shape,
+              std::vector{strides_in_pre, strides_out_pre},
+              /* size_cap = */ INT32_MAX);
+          return std::make_tuple(shape, strides[0], strides[1]);
+        } else {
+          std::vector<stride_t> e;
+          return std::make_tuple(std::vector<int>{}, e, e);
+        }
+      };
+  auto [shape, strides_in_, strides_out_] = maybe_collapse();
 
   bool use_2d = out.data_size() > UINT32_MAX;
   auto& d = metal::device(s.device);
diff --git a/mlx/backend/metal/jit/copy.h b/mlx/backend/metal/jit/copy.h
deleted file mode 100644
index 167be8f84..000000000
--- a/mlx/backend/metal/jit/copy.h
+++ /dev/null
@@ -1,100 +0,0 @@
-// Copyright © 2024 Apple Inc.
-
-constexpr std::string_view copy_kernels = R"(
-template [[host_name("s_{0}")]] [[kernel]] void copy_s<{1}, {2}>(
-    device const {1}* src [[buffer(0)]],
-    device {2}* dst [[buffer(1)]],
-    uint index [[thread_position_in_grid]]);
-template [[host_name("v_{0}")]] [[kernel]] void copy_v<{1}, {2}>(
-    device const {1}* src [[buffer(0)]],
-    device {2}* dst [[buffer(1)]],
-    uint index [[thread_position_in_grid]]);
-
-template [[host_name("g4_{0}")]] [[kernel]] void
-copy_g_nd<{1}, {2}, 4>(
-    device const {1}* src [[buffer(0)]],
-    device {2}* dst [[buffer(1)]],
-    constant const int* src_shape [[buffer(2)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    uint3 index [[thread_position_in_grid]],
-    uint3 grid_dim [[threads_per_grid]]);
-template [[host_name("gg4_{0}")]] [[kernel]] void
-copy_gg_nd<{1}, {2}, 4>(
-    device const {1}* src [[buffer(0)]],
-    device {2}* dst [[buffer(1)]],
-    constant const int* src_shape [[buffer(2)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    constant const int64_t* dst_strides [[buffer(4)]],
-    uint3 index [[thread_position_in_grid]]);
-template [[host_name("g5_{0}")]] [[kernel]] void
-copy_g_nd<{1}, {2}, 5>(
-    device const {1}* src [[buffer(0)]],
-    device {2}* dst [[buffer(1)]],
-    constant const int* src_shape [[buffer(2)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    uint3 index [[thread_position_in_grid]],
-    uint3 grid_dim [[threads_per_grid]]);
-template [[host_name("gg5_{0}")]] [[kernel]] void
-copy_gg_nd<{1}, {2}, 5>(
-    device const {1}* src [[buffer(0)]],
-    device {2}* dst [[buffer(1)]],
-    constant const int* src_shape [[buffer(2)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    constant const int64_t* dst_strides [[buffer(4)]],
-    uint3 index [[thread_position_in_grid]]);
-template [[host_name("g1_{0}")]] [[kernel]] void copy_g_nd1<{1}, {2}>(
-    device const {1}* src [[buffer(0)]],
-    device {2}* dst [[buffer(1)]],
-    constant const int64_t& src_stride [[buffer(3)]],
-    uint index [[thread_position_in_grid]]);
-template [[host_name("g2_{0}")]] [[kernel]] void copy_g_nd2<{1}, {2}>(
-    device const {1}* src [[buffer(0)]],
-    device {2}* dst [[buffer(1)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    uint2 index [[thread_position_in_grid]],
-    uint2 grid_dim [[threads_per_grid]]);
-template [[host_name("g3_{0}")]] [[kernel]] void copy_g_nd3<{1}, {2}>(
-    device const {1}* src [[buffer(0)]],
-    device {2}* dst [[buffer(1)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    uint3 index [[thread_position_in_grid]],
-    uint3 grid_dim [[threads_per_grid]]);
-template [[host_name("gg1_{0}")]] [[kernel]] void
-copy_gg_nd1<{1}, {2}>(
-    device const {1}* src [[buffer(0)]],
-    device {2}* dst [[buffer(1)]],
-    constant const int64_t& src_stride [[buffer(3)]],
-    constant const int64_t& dst_stride [[buffer(4)]],
-    uint index [[thread_position_in_grid]]);
-template [[host_name("gg2_{0}")]] [[kernel]] void
-copy_gg_nd2<{1}, {2}>(
-    device const {1}* src [[buffer(0)]],
-    device {2}* dst [[buffer(1)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    constant const int64_t* dst_strides [[buffer(4)]],
-    uint2 index [[thread_position_in_grid]]);
-template [[host_name("gg3_{0}")]] [[kernel]] void
-copy_gg_nd3<{1}, {2}>(
-    device const {1}* src [[buffer(0)]],
-    device {2}* dst [[buffer(1)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    constant const int64_t* dst_strides [[buffer(4)]],
-    uint3 index [[thread_position_in_grid]]);
-
-template [[host_name("g_{0}")]] [[kernel]] void copy_g<{1}, {2}>(
-    device const {1}* src [[buffer(0)]],
-    device {2}* dst [[buffer(1)]],
-    constant const int* src_shape [[buffer(2)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    constant const int& ndim [[buffer(5)]],
-    uint3 index [[thread_position_in_grid]],
-    uint3 grid_dim [[threads_per_grid]]);
-template [[host_name("gg_{0}")]] [[kernel]] void copy_gg<{1}, {2}>(
-    device const {1}* src [[buffer(0)]],
-    device {2}* dst [[buffer(1)]],
-    constant const int* src_shape [[buffer(2)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    constant const int64_t* dst_strides [[buffer(4)]],
-    constant const int& ndim [[buffer(5)]],
-    uint3 index [[thread_position_in_grid]]);
-)";
diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp
index 346954be4..2c22e9668 100644
--- a/mlx/backend/metal/jit_kernels.cpp
+++ b/mlx/backend/metal/jit_kernels.cpp
@@ -1,9 +1,7 @@
 // Copyright © 2024 Apple Inc.
-#include <map>
 
 #include "mlx/backend/common/compiled.h"
 #include "mlx/backend/metal/jit/arange.h"
-#include "mlx/backend/metal/jit/copy.h"
 #include "mlx/backend/metal/jit/gemv_masked.h"
 #include "mlx/backend/metal/jit/includes.h"
 #include "mlx/backend/metal/jit/scan.h"
@@ -67,7 +65,7 @@ void add_binary_kernels(
     Dtype out_type,
     const std::string op,
     std::ostringstream& kernel_source) {
-  const std::map<std::string, std::string> kernel_types = {
+  const std::array<std::pair<std::string, std::string>, 11> kernel_types = {{
       {"ss", "binary_ss"},
       {"vs", "binary_vs"},
       {"sv", "binary_sv"},
@@ -78,29 +76,16 @@ void add_binary_kernels(
       {"g1", "binary_g_nd1"},
       {"g2", "binary_g_nd2"},
       {"g3", "binary_g_nd3"},
-      {"g4", "binary_g_nd"},
-      {"g5", "binary_g_nd"},
       {"gn", "binary_g"},
-  };
-  for (auto [name, func] : kernel_types) {
+  }};
+  for (auto& [name, func] : kernel_types) {
     std::string template_def;
-    if (name == "g4" || name == "g5") {
-      int dim = std::stoi(name.substr(1));
-      template_def = get_template_definition(
-          name + lib_name,
-          func,
-          get_type_string(in_type),
-          get_type_string(out_type),
-          op,
-          dim);
-    } else {
-      template_def = get_template_definition(
-          name + lib_name,
-          func,
-          get_type_string(in_type),
-          get_type_string(out_type),
-          op);
-    }
+    template_def = get_template_definition(
+        name + lib_name,
+        func,
+        get_type_string(in_type),
+        get_type_string(out_type),
+        op);
     kernel_source << template_def;
   }
 }
@@ -149,27 +134,19 @@ MTL::ComputePipelineState* get_ternary_kernel(
   auto lib = d.get_library(lib_name);
   if (lib == nullptr) {
     std::ostringstream kernel_source;
-    const std::map<std::string, std::string> kernel_types = {
+    const std::array<std::pair<std::string, std::string>, 6> kernel_types = {{
        {"v", "ternary_v"},
        {"v2", "ternary_v2"},
        {"g", "ternary_g"},
        {"g1", "ternary_g_nd1"},
        {"g2", "ternary_g_nd2"},
        {"g3", "ternary_g_nd3"},
-        {"g4", "ternary_g_nd"},
-        {"g5", "ternary_g_nd"},
-    };
+    }};
     kernel_source << metal::utils() << metal::ternary_ops() << metal::ternary();
-    for (auto [name, func] : kernel_types) {
+    for (auto& [name, func] : kernel_types) {
       std::string template_def;
-      if (name == "g4" || name == "g5") {
-        int dim = std::stoi(name.substr(1));
-        template_def = get_template_definition(
-            name + "_" + lib_name, func, get_type_string(type), op, dim);
-      } else {
-        template_def = get_template_definition(
-            name + "_" + lib_name, func, get_type_string(type), op);
-      }
+      template_def = get_template_definition(
+          name + "_" + lib_name, func, get_type_string(type), op);
       kernel_source << template_def;
     }
     lib = d.get_library(lib_name, kernel_source.str());
@@ -186,12 +163,27 @@ MTL::ComputePipelineState* get_copy_kernel(
   auto lib = d.get_library(lib_name);
   if (lib == nullptr) {
     std::ostringstream kernel_source;
-    kernel_source << metal::utils() << metal::copy()
-                  << fmt::format(
-                         copy_kernels,
-                         lib_name,
-                         get_type_string(in.dtype()),
-                         get_type_string(out.dtype()));
+    auto in_type = get_type_string(in.dtype());
+    auto out_type = get_type_string(out.dtype());
+    kernel_source
+        << metal::utils() << metal::copy()
+        << get_template_definition("s_" + lib_name, "copy_s", in_type, out_type)
+        << get_template_definition("v_" + lib_name, "copy_v", in_type, out_type)
+        << get_template_definition(
+               "g1_" + lib_name, "copy_g_nd1", in_type, out_type)
+        << get_template_definition(
+               "g2_" + lib_name, "copy_g_nd2", in_type, out_type)
+        << get_template_definition(
+               "g3_" + lib_name, "copy_g_nd3", in_type, out_type)
+        << get_template_definition("g_" + lib_name, "copy_g", in_type, out_type)
+        << get_template_definition(
+               "gg1_" + lib_name, "copy_gg_nd1", in_type, out_type)
+        << get_template_definition(
+               "gg2_" + lib_name, "copy_gg_nd2", in_type, out_type)
+        << get_template_definition(
+               "gg3_" + lib_name, "copy_gg_nd3", in_type, out_type)
+        << get_template_definition(
+               "gg_" + lib_name, "copy_gg", in_type, out_type);
     lib = d.get_library(lib_name, kernel_source.str());
   }
   return d.get_kernel(kernel_name, lib);
@@ -296,11 +288,11 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
   if (lib == nullptr) {
     std::ostringstream kernel_source;
     kernel_source << metal::utils() << metal::sort();
-    std::vector<std::pair<std::string, std::string>> kernel_types = {
-        {"sort_", "mb_block_sort"},
-        {"partition_", "mb_block_partition"},
-        {"merge_", "mb_block_merge"}};
-    for (auto [name, func] : kernel_types) {
+    std::array<std::pair<std::string, std::string>, 3> kernel_types = {
+        {{"sort_", "mb_block_sort"},
+         {"partition_", "mb_block_partition"},
+         {"merge_", "mb_block_merge"}}};
+    for (auto& [name, func] : kernel_types) {
       kernel_source << get_template_definition(
           name + lib_name,
           func,
diff --git a/mlx/backend/metal/kernels/arg_reduce.metal b/mlx/backend/metal/kernels/arg_reduce.metal
index 61df1273e..fa32dec4f 100644
--- a/mlx/backend/metal/kernels/arg_reduce.metal
+++ b/mlx/backend/metal/kernels/arg_reduce.metal
@@ -70,16 +70,16 @@ IndexValPair<U> simd_shuffle_down(IndexValPair<U> data, uint16_t delta) {
       simd_shuffle_down(data.index, delta),
       simd_shuffle_down(data.val, delta)};
 }
 
-template <typename T, typename Op, int N_READS>
+template <typename T, typename Op, int N_READS = 4>
 [[kernel]] void arg_reduce_general(
     const device T* in [[buffer(0)]],
     device uint32_t* out [[buffer(1)]],
-    const device int* shape [[buffer(2)]],
-    const device size_t* in_strides [[buffer(3)]],
-    const device size_t* out_strides [[buffer(4)]],
-    const device size_t& ndim [[buffer(5)]],
-    const device size_t& axis_stride [[buffer(6)]],
-    const device size_t& axis_size [[buffer(7)]],
+    const constant int* shape [[buffer(2)]],
+    const constant size_t* in_strides [[buffer(3)]],
+    const constant size_t* out_strides [[buffer(4)]],
+    const constant size_t& ndim [[buffer(5)]],
+    const constant size_t& axis_stride [[buffer(6)]],
+    const constant size_t& axis_size [[buffer(7)]],
     uint gid [[thread_position_in_grid]],
     uint lid [[thread_position_in_threadgroup]],
     uint lsize [[threads_per_threadgroup]],
@@ -159,28 +159,12 @@
   }
 }
 
-#define instantiate_arg_reduce_helper(name, itype, op)                  \
-  template [[host_name(name)]] [[kernel]] void                          \
-  arg_reduce_general<itype, op<itype>, 4>(                              \
-      const device itype* in [[buffer(0)]],                             \
-      device uint32_t* out [[buffer(1)]],                               \
-      const device int* shape [[buffer(2)]],                            \
-      const device size_t* in_strides [[buffer(3)]],                    \
-      const device size_t* out_strides [[buffer(4)]],                   \
-      const device size_t& ndim [[buffer(5)]],                          \
-      const device size_t& axis_stride [[buffer(6)]],                   \
-      const device size_t& axis_size [[buffer(7)]],                     \
-      uint gid [[thread_position_in_grid]],                             \
-      uint lid [[thread_position_in_threadgroup]],                      \
-      uint lsize [[threads_per_threadgroup]],                           \
-      uint simd_size [[threads_per_simdgroup]],                         \
-      uint simd_lane_id [[thread_index_in_simdgroup]],                  \
-      uint simd_group_id [[simdgroup_index_in_threadgroup]]);
-
 // clang-format off
 #define instantiate_arg_reduce(name, itype) \
-  instantiate_arg_reduce_helper("argmin_" #name , itype, ArgMin) \
-  instantiate_arg_reduce_helper("argmax_" #name , itype, ArgMax)
+  instantiate_kernel(                       \
+      "argmin_" #name, arg_reduce_general, itype, ArgMin<itype>) \
+  instantiate_kernel(                       \
+      "argmax_" #name, arg_reduce_general, itype, ArgMax<itype>)
 
 instantiate_arg_reduce(bool_, bool)
 instantiate_arg_reduce(uint8, uint8_t)
diff --git a/mlx/backend/metal/kernels/binary.h b/mlx/backend/metal/kernels/binary.h
index 2e668621b..c5a584b6d 100644
--- a/mlx/backend/metal/kernels/binary.h
+++ b/mlx/backend/metal/kernels/binary.h
@@ -93,7 +93,7 @@ template <typename T, typename U, typename Op>
     uint2 grid_dim [[threads_per_grid]]) {
   auto a_idx = elem_to_loc_2(index, a_strides);
   auto b_idx = elem_to_loc_2(index, b_strides);
-  size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
+  size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
   c[out_idx] = Op()(a[a_idx], b[b_idx]);
 }
 
@@ -109,26 +109,10 @@ template <typename T, typename U, typename Op>
   auto a_idx = elem_to_loc_3(index, a_strides);
   auto b_idx = elem_to_loc_3(index, b_strides);
   size_t out_idx =
-      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
+      index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
   c[out_idx] = Op()(a[a_idx], b[b_idx]);
 }
 
-template <typename T, typename U, typename Op, int DIM>
-[[kernel]] void binary_g_nd(
-    device const T* a,
-    device const T* b,
-    device U* c,
-    constant const int shape[DIM],
-    constant const size_t a_strides[DIM],
-    constant const size_t b_strides[DIM],
-    uint3 index [[thread_position_in_grid]],
-    uint3 grid_dim [[threads_per_grid]]) {
-  auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
-  size_t out_idx =
-      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
-  c[out_idx] = Op()(a[idx.x], b[idx.y]);
-}
-
 template <typename T, typename U, typename Op>
 [[kernel]] void binary_g(
     device const T* a,
@@ -141,6 +125,7 @@ template <typename T, typename U, typename Op>
     uint3 index [[thread_position_in_grid]],
     uint3 grid_dim [[threads_per_grid]]) {
   auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
-  size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
+  size_t out_idx =
+      index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
   c[out_idx] = Op()(a[idx.x], b[idx.y]);
 }
diff --git a/mlx/backend/metal/kernels/binary.metal b/mlx/backend/metal/kernels/binary.metal
index 2c302c20b..5600de23e 100644
--- a/mlx/backend/metal/kernels/binary.metal
+++ b/mlx/backend/metal/kernels/binary.metal
@@ -21,8 +21,6 @@
   instantiate_kernel("g1" #op #tname, binary_g_nd1, itype, otype, op)    \
   instantiate_kernel("g2" #op #tname, binary_g_nd2, itype, otype, op)    \
   instantiate_kernel("g3" #op #tname, binary_g_nd3, itype, otype, op)    \
-  instantiate_kernel("g4" #op #tname, binary_g_nd, itype, otype, op, 4)  \
-  instantiate_kernel("g5" #op #tname, binary_g_nd, itype, otype, op, 5)
 
 #define instantiate_binary_integer(op)                  \
   instantiate_binary_all(op, uint8, uint8_t, uint8_t)   \
diff --git a/mlx/backend/metal/kernels/binary_two.h b/mlx/backend/metal/kernels/binary_two.h
index 08ff876ca..f40d81e86 100644
--- a/mlx/backend/metal/kernels/binary_two.h
+++ b/mlx/backend/metal/kernels/binary_two.h
@@ -118,7 +118,7 @@ template <typename T, typename U, typename Op>
     uint2 grid_dim [[threads_per_grid]]) {
   auto a_idx = elem_to_loc_2(index, a_strides);
   auto b_idx = elem_to_loc_2(index, b_strides);
-  size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
+  size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
   auto out = Op()(a[a_idx], b[b_idx]);
   c[out_idx] = out[0];
   d[out_idx] = out[1];
@@ -137,31 +137,12 @@ template <typename T, typename U, typename Op>
   auto a_idx = elem_to_loc_3(index, a_strides);
   auto b_idx = elem_to_loc_3(index, b_strides);
   size_t out_idx =
-      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
+      index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
   auto out = Op()(a[a_idx], b[b_idx]);
   c[out_idx] = out[0];
   d[out_idx] = out[1];
 }
 
-template <typename T, typename U, typename Op, int DIM>
-[[kernel]] void binary_g_nd(
-    device const T* a,
-    device const T* b,
-    device U* c,
-    device U* d,
-    constant const int shape[DIM],
-    constant const size_t a_strides[DIM],
-    constant const size_t b_strides[DIM],
-    uint3 index [[thread_position_in_grid]],
-    uint3 grid_dim [[threads_per_grid]]) {
-  auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
-  size_t out_idx =
-      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
-  auto out = Op()(a[idx.x], b[idx.y]);
-  c[out_idx] = out[0];
-  d[out_idx] = out[1];
-}
-
 template <typename T, typename U, typename Op>
 [[kernel]] void binary_g(
     device const T* a,
@@ -175,7 +156,8 @@ template <typename T, typename U, typename Op>
     uint3 index [[thread_position_in_grid]],
     uint3 grid_dim [[threads_per_grid]]) {
   auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
-  size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
+  size_t out_idx =
+      index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
   auto out = Op()(a[idx.x], b[idx.y]);
   c[out_idx] = out[0];
   d[out_idx] = out[1];
diff --git a/mlx/backend/metal/kernels/binary_two.metal b/mlx/backend/metal/kernels/binary_two.metal
index fb1bd785b..8481776aa 100644
--- a/mlx/backend/metal/kernels/binary_two.metal
+++ b/mlx/backend/metal/kernels/binary_two.metal
@@ -19,8 +19,6 @@
   instantiate_kernel("g1" #op #tname, binary_g_nd1, itype, otype, op) \
   instantiate_kernel("g2" #op #tname, binary_g_nd2, itype, otype, op) \
   instantiate_kernel("g3" #op #tname, binary_g_nd3, itype, otype, op) \
-  instantiate_kernel("g4" #op #tname, binary_g_nd, itype, otype, op, 4) \
-  instantiate_kernel("g5" #op #tname, binary_g_nd, itype, otype, op, 5)
 
 #define instantiate_binary_float(op)  \
   instantiate_binary_all(op, float16, half, half) \
diff --git a/mlx/backend/metal/kernels/copy.h b/mlx/backend/metal/kernels/copy.h
index 6ba5ed741..2d836ff65 100644
--- a/mlx/backend/metal/kernels/copy.h
+++ b/mlx/backend/metal/kernels/copy.h
@@ -71,20 +71,6 @@ template <typename T, typename U>
   dst[dst_idx] = static_cast<U>(src[src_idx]);
 }
 
-template <typename T, typename U, int DIM>
-[[kernel]] void copy_g_nd(
-    device const T* src [[buffer(0)]],
-    device U* dst [[buffer(1)]],
-    constant const int* src_shape [[buffer(2)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    uint3 index [[thread_position_in_grid]],
-    uint3 grid_dim [[threads_per_grid]]) {
-  auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
-  int64_t dst_idx =
-      index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
-  dst[dst_idx] = static_cast<U>(src[src_idx]);
-}
-
 template <typename T, typename U>
 [[kernel]] void copy_g(
     device const T* src [[buffer(0)]],
@@ -136,19 +122,6 @@ template <typename T, typename U>
   dst[dst_idx] = static_cast<U>(src[src_idx]);
 }
 
-template <typename T, typename U, int DIM>
-[[kernel]] void copy_gg_nd(
-    device const T* src [[buffer(0)]],
-    device U* dst [[buffer(1)]],
-    constant const int* src_shape [[buffer(2)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    constant const int64_t* dst_strides [[buffer(4)]],
-    uint3 index [[thread_position_in_grid]]) {
-  auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
-  auto dst_idx = elem_to_loc_nd<DIM>(index, src_shape, dst_strides);
-  dst[dst_idx] = static_cast<U>(src[src_idx]);
-}
-
 template <typename T, typename U>
 [[kernel]] void copy_gg(
     device const T* src [[buffer(0)]],
diff --git a/mlx/backend/metal/kernels/copy.metal b/mlx/backend/metal/kernels/copy.metal
index a121197e5..76cfbb867 100644
--- a/mlx/backend/metal/kernels/copy.metal
+++ b/mlx/backend/metal/kernels/copy.metal
@@ -16,10 +16,6 @@
   instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype) \
   instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype) \
   instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype) \
-  instantiate_kernel("g4_copy" #tname, copy_g_nd, itype, otype, 4) \
-  instantiate_kernel("g5_copy" #tname, copy_g_nd, itype, otype, 5) \
-  instantiate_kernel("gg4_copy" #tname, copy_gg_nd, itype, otype, 4) \
-  instantiate_kernel("gg5_copy" #tname, copy_gg_nd, itype, otype, 5) \
   instantiate_kernel("g_copy" #tname, copy_g, itype, otype) \
   instantiate_kernel("gg_copy" #tname, copy_gg, itype, otype)
diff --git a/mlx/backend/metal/kernels/random.metal b/mlx/backend/metal/kernels/random.metal
index 58073d649..39bc15953 100644
--- a/mlx/backend/metal/kernels/random.metal
+++ b/mlx/backend/metal/kernels/random.metal
@@ -69,9 +69,9 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
     device char* out,
     device const bool& odd,
     device const uint& bytes_per_key,
-    device const int& ndim,
-    device const int* key_shape,
-    device const size_t* key_strides,
+    constant const int& ndim,
+    constant const int* key_shape,
+    constant const size_t* key_strides,
     uint2 grid_dim [[threads_per_grid]],
     uint2 index [[thread_position_in_grid]]) {
   auto kidx = 2 * index.x;
diff --git a/mlx/backend/metal/kernels/sort.h b/mlx/backend/metal/kernels/sort.h
index eb1f70d19..bfa4d98ea 100644
--- a/mlx/backend/metal/kernels/sort.h
+++ b/mlx/backend/metal/kernels/sort.h
@@ -342,9 +342,9 @@ template <
     const constant int& in_stride_sorted_axis [[buffer(3)]],
     const constant int& out_stride_sorted_axis [[buffer(4)]],
     const constant int& nc_dim [[buffer(5)]],
-    const device int* nc_shape [[buffer(6)]],
-    const device size_t* in_nc_strides [[buffer(7)]],
-    const device size_t* out_nc_strides [[buffer(8)]],
+    const constant int* nc_shape [[buffer(6)]],
+    const constant size_t* in_nc_strides [[buffer(7)]],
+    const constant size_t* out_nc_strides [[buffer(8)]],
     uint3 tid [[threadgroup_position_in_grid]],
     uint3 lid [[thread_position_in_threadgroup]]) {
   using sort_kernel =
@@ -485,8 +485,8 @@ template <
     const constant int& size_sorted_axis [[buffer(3)]],
     const constant int& stride_sorted_axis [[buffer(4)]],
     const constant int& nc_dim [[buffer(5)]],
-    const device int* nc_shape [[buffer(6)]],
-    const device size_t* nc_strides [[buffer(7)]],
+    const constant int* nc_shape [[buffer(6)]],
+    const constant size_t* nc_strides [[buffer(7)]],
     uint3 tid [[threadgroup_position_in_grid]],
     uint3 lid [[thread_position_in_threadgroup]]) {
   using sort_kernel = KernelMultiBlockMergeSort<
diff --git a/mlx/backend/metal/kernels/ternary.h b/mlx/backend/metal/kernels/ternary.h
index 370480d35..7cc062500 100644
--- a/mlx/backend/metal/kernels/ternary.h
+++ b/mlx/backend/metal/kernels/ternary.h
@@ -52,7 +52,7 @@ template <typename T, typename Op>
   auto a_idx = elem_to_loc_2(index, a_strides);
   auto b_idx = elem_to_loc_2(index, b_strides);
   auto c_idx = elem_to_loc_2(index, c_strides);
-  size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
+  size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
   d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
 }
 
@@ -71,29 +71,10 @@ template <typename T, typename Op>
   auto b_idx = elem_to_loc_3(index, b_strides);
   auto c_idx = elem_to_loc_3(index, c_strides);
   size_t out_idx =
-      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
+      index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
   d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
 }
 
-template <typename T, typename Op, int DIM>
-[[kernel]] void ternary_g_nd(
-    device const bool* a,
-    device const T* b,
-    device const T* c,
-    device T* d,
-    constant const int shape[DIM],
-    constant const size_t a_strides[DIM],
-    constant const size_t b_strides[DIM],
-    constant const size_t c_strides[DIM],
-    uint3 index [[thread_position_in_grid]],
-    uint3 grid_dim [[threads_per_grid]]) {
-  auto idx =
-      elem_to_loc_3_nd<DIM>(index, shape, a_strides, b_strides, c_strides);
-  size_t out_idx =
-      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
-  d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
-}
-
 template <typename T, typename Op>
 [[kernel]] void ternary_g(
     device const bool* a,
@@ -109,6 +90,7 @@ template <typename T, typename Op>
     uint3 grid_dim [[threads_per_grid]]) {
   auto idx =
       elem_to_loc_3_nd(index, shape, a_strides, b_strides, c_strides, ndim);
-  size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
+  size_t out_idx =
+      index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
   d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
 }
diff --git a/mlx/backend/metal/kernels/ternary.metal b/mlx/backend/metal/kernels/ternary.metal
index 4101229b9..47894594c 100644
--- a/mlx/backend/metal/kernels/ternary.metal
+++ b/mlx/backend/metal/kernels/ternary.metal
@@ -16,8 +16,6 @@
   instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op) \
   instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op) \
   instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op) \
-  instantiate_kernel("g4_" #op #tname, ternary_g_nd, type, op, 4) \
-  instantiate_kernel("g5_" #op #tname, ternary_g_nd, type, op, 5)
 
 #define instantiate_ternary_types(op) \
   instantiate_ternary_all(op, bool_, bool) \
diff --git a/mlx/backend/metal/kernels/unary.h b/mlx/backend/metal/kernels/unary.h
index e904e1629..ecdf34e1d 100644
--- a/mlx/backend/metal/kernels/unary.h
+++ b/mlx/backend/metal/kernels/unary.h
@@ -22,8 +22,8 @@ template <typename T, typename Op>
 [[kernel]] void unary_g(
     device const T* in,
     device T* out,
-    device const int* in_shape,
-    device const size_t* in_strides,
+    constant const int* in_shape,
+    constant const size_t* in_strides,
     device const int& ndim,
     uint index [[thread_position_in_grid]]) {
   auto idx = elem_to_loc(index, in_shape, in_strides, ndim);
diff --git a/mlx/backend/metal/kernels/utils.h b/mlx/backend/metal/kernels/utils.h
index 17d71b880..0ec69e191 100644
--- a/mlx/backend/metal/kernels/utils.h
+++ b/mlx/backend/metal/kernels/utils.h
@@ -83,20 +83,6 @@ struct Limits<complex64_t> {
 ///////////////////////////////////////////////////////////////////////////////
 // Single Array with generic dims
 
-template <typename stride_t>
-METAL_FUNC stride_t elem_to_loc(
-    uint elem,
-    device const int* shape,
-    device const stride_t* strides,
-    int ndim) {
-  stride_t loc = 0;
-  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
-    loc += (elem % shape[i]) * strides[i];
-    elem /= shape[i];
-  }
-  return loc;
-}
-
 template <typename stride_t>
 METAL_FUNC stride_t elem_to_loc(
     uint elem,
     constant const int* shape,
     constant const stride_t* strides,
     int ndim) {
   stride_t loc = 0;
   for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
     loc += (elem % shape[i]) * strides[i];
     elem /= shape[i];
   }
   return loc;
 }
 
-template <typename stride_t>
-METAL_FUNC stride_t elem_to_loc(
-    stride_t elem,
-    device const int* shape,
-    device const stride_t* strides,
-    int ndim) {
-  stride_t loc = 0;
-  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
-    loc += (elem % shape[i]) * strides[i];
-    elem /= shape[i];
-  }
-  return loc;
-}
-
 template <typename stride_t>
 METAL_FUNC stride_t elem_to_loc(
     stride_t elem,
     constant const int* shape,
     constant const stride_t* strides,
     int ndim) {
@@ -174,78 +146,18 @@ elem_to_loc_3(uint3 elem, constant const stride_t strides[3]) {
   return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
 }
 
-template <int NDIM>
-METAL_FUNC size_t elem_to_loc_nd(
-    uint elem,
-    device const int* shape,
-    device const size_t* strides) {
-  size_t loc = (elem % shape[NDIM - 1]) * strides[NDIM - 1];
-
-  MLX_MTL_PRAGMA_UNROLL
-  for (int d = NDIM - 2; d >= 0; --d) {
-    elem /= shape[d + 1];
-    loc += (elem % shape[d]) * strides[d];
-  }
-
-  return loc;
-}
-
-template <int NDIM>
-METAL_FUNC size_t elem_to_loc_nd(
-    uint3 elem,
-    constant const int shape[NDIM],
-    constant const size_t strides[NDIM]) {
-  size_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
-  for (int d = NDIM - 3; d >= 0; --d) {
-    loc += (elem.z % shape[d]) * strides[d];
-    elem.z /= shape[d];
-  }
-  return loc;
-}
-
-template <int NDIM>
-METAL_FUNC int64_t elem_to_loc_nd(
-    uint elem,
-    constant const int shape[NDIM],
-    constant const int64_t strides[NDIM]) {
-  int64_t loc = (elem % shape[NDIM - 1]) * strides[NDIM - 1];
-
-  MLX_MTL_PRAGMA_UNROLL
-  for (int d = NDIM - 2; d >= 0; --d) {
-    elem /= shape[d + 1];
-    loc += (elem % shape[d]) * strides[d];
-  }
-
-  return loc;
-}
-
-template <int NDIM>
-METAL_FUNC int64_t elem_to_loc_nd(
-    uint3 elem,
-    constant const int shape[NDIM],
-    constant const int64_t strides[NDIM]) {
-  int64_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
-  for (int d = NDIM - 3; d >= 0; --d) {
-    loc += (elem.z % shape[d]) * strides[d];
-    elem.z /= shape[d];
-  }
-  return loc;
-}
-
 ///////////////////////////////////////////////////////////////////////////////
 // Multiple Arrays with generic dims
 
-METAL_FUNC uint2 elem_to_loc_2_nd(
+METAL_FUNC ulong2 elem_to_loc_2_nd(
     uint3 elem,
     constant const int* shape,
     constant const size_t* a_strides,
     constant const size_t* b_strides,
     int ndim) {
-  uint2 loc = {
-      static_cast<uint>(
-          elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
-      static_cast<uint>(
-          elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
+  ulong2 loc = {
+      elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2],
+      elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2]};
   for (int d = ndim - 3; d >= 0; --d) {
     uint l = elem.z % shape[d];
     loc.x += l * a_strides[d];
     loc.y += l * b_strides[d];
     elem.z /= shape[d];
   }
   return loc;
 }
 
-METAL_FUNC uint3 elem_to_loc_3_nd(
+METAL_FUNC ulong3 elem_to_loc_3_nd(
     uint3 elem,
     constant const int* shape,
     constant const size_t* a_strides,
     constant const size_t* b_strides,
     constant const size_t* c_strides,
     int ndim) {
-  uint3 loc = {
-      static_cast<uint>(
-          elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
-      static_cast<uint>(
-          elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2]),
-      static_cast<uint>(
-          elem.x * c_strides[ndim - 1] + elem.y * c_strides[ndim - 2])};
+  ulong3 loc = {
+      elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2],
+      elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2],
+      elem.x * c_strides[ndim - 1] + elem.y * c_strides[ndim - 2]};
   for (int d = ndim - 3; d >= 0; --d) {
     uint l = elem.z % shape[d];
     loc.x += l * a_strides[d];
     loc.y += l * b_strides[d];
     loc.z += l * c_strides[d];
     elem.z /= shape[d];
   }
   return loc;
 }
 
-///////////////////////////////////////////////////////////////////////////////
-// Multiple Arrays with fixed N dims
-
-template <int NDIM>
-METAL_FUNC uint2 elem_to_loc_2_nd(
-    uint3 elem,
-    constant const int shape[NDIM],
-    constant const size_t a_strides[NDIM],
-    constant const size_t b_strides[NDIM]) {
-  uint2 loc = {
-      static_cast<uint>(
-          elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
-      static_cast<uint>(
-          elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2])};
-  for (int d = NDIM - 3; d >= 0; --d) {
-    uint l = elem.z % shape[d];
-    loc.x += l * a_strides[d];
-    loc.y += l * b_strides[d];
-    elem.z /= shape[d];
-  }
-  return loc;
-}
-
-template <int NDIM>
-METAL_FUNC uint3 elem_to_loc_3_nd(
-    uint3 elem,
-    constant const int shape[NDIM],
-    constant const size_t a_strides[NDIM],
-    constant const size_t b_strides[NDIM],
-    constant const size_t c_strides[NDIM]) {
-  uint3 loc = {
-      static_cast<uint>(
-          elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
-      static_cast<uint>(
-          elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2]),
-      static_cast<uint>(
-          elem.x * c_strides[NDIM - 1] + elem.y * c_strides[NDIM - 2])};
-  for (int d = NDIM - 3; d >= 0; --d) {
-    uint l = elem.z % shape[d];
-    loc.x += l * a_strides[d];
-    loc.y += l * b_strides[d];
-    loc.z += l * c_strides[d];
-    elem.z /= shape[d];
-  }
-  return loc;
-}
-
 ///////////////////////////////////////////////////////////////////////////////
 // Elem to loc in a loop utils
 ///////////////////////////////////////////////////////////////////////////////
diff --git a/mlx/backend/metal/ternary.cpp b/mlx/backend/metal/ternary.cpp
index c214db267..3c109018b 100644
--- a/mlx/backend/metal/ternary.cpp
+++ b/mlx/backend/metal/ternary.cpp
@@ -8,7 +8,7 @@
 
 namespace mlx::core {
 
-constexpr int MAX_TERNARY_SPECIALIZED_DIMS = 5;
+constexpr int MAX_TERNARY_SPECIALIZED_DIMS = 3;
 
 void ternary_op_gpu_inplace(
     const std::vector<array>& inputs,
@@ -26,11 +26,21 @@ void ternary_op_gpu_inplace(
   }
 
   // Try to collapse contiguous dims
-  auto [shape, strides] = collapse_contiguous_dims(a, b, c, out);
-  auto& strides_a = strides[0];
-  auto& strides_b = strides[1];
-  auto& strides_c = strides[2];
-  auto& strides_out = strides[3];
+  auto maybe_collapse = [topt, &a, &b, &c, &out]() {
+    if (topt == TernaryOpType::General) {
+      // The size cap here should ideally be `UINT32_MAX` but we are
+      // limited by the shape being an int.
+      auto [shape, strides] = collapse_contiguous_dims(
+          {a, b, c, out},
+          /* size_cap = */ INT32_MAX);
+      return std::make_tuple(
+          shape, strides[0], strides[1], strides[2], strides[3]);
+    } else {
+      std::vector<size_t> e;
+      return std::make_tuple(std::vector<int>{}, e, e, e, e);
+    }
+  };
+  auto [shape, strides_a, strides_b, strides_c, strides_out] =
+      maybe_collapse();
 
   bool use_2d = out.data_size() > UINT_MAX;
   std::string kernel_name;
@@ -88,7 +98,7 @@ void ternary_op_gpu_inplace(
     size_t rest = out.size() / (dim0 * dim1);
     NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
     if (thread_group_size != 1024) {
-      throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
+      throw std::runtime_error("[Metal::ternary] Must use 1024 sized block");
     }
     MTL::Size group_dims = get_block_dims(dim0, dim1, rest);
     MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
diff --git a/mlx/backend/no_cpu/CMakeLists.txt b/mlx/backend/no_cpu/CMakeLists.txt
index 6c5f8017a..029c720d7 100644
--- a/mlx/backend/no_cpu/CMakeLists.txt
+++ b/mlx/backend/no_cpu/CMakeLists.txt
@@ -6,4 +6,5 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/../common/compiled.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/../common/compiled_nocpu.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/../common/reduce_utils.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/../common/slicing.cpp)
+  ${CMAKE_CURRENT_SOURCE_DIR}/../common/slicing.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/../common/utils.cpp)
diff --git a/python/src/fast.cpp b/python/src/fast.cpp
index 92e6e6bdc..758a27530 100644
--- a/python/src/fast.cpp
+++ b/python/src/fast.cpp
@@ -243,7 +243,7 @@ void init_fast(nb::module_& parent_module) {
             template_args.emplace_back(name, dtype);
           } else {
             throw std::invalid_argument(
-                "[[metal_kernel]] Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`.");
+                "[metal_kernel] Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`.");
           }
         }
       }
@@ -271,25 +271,24 @@ void init_fast(nb::module_& parent_module) {
          nb::sig(
              "def __call__(self, *, inputs: List[Union[scalar, array]], output_shapes: List[Sequence[int]], output_dtypes: List[Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[List[Tuple[str, Union[bool, int, Dtype]]]] = None, init_value: Optional[float] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)"),
          R"pbdoc(
-            Run the kernel.
+        Run the kernel.
 
-            Args:
-              inputs (List[array]): The inputs passed to the Metal kernel.
-              output_shapes (List[Sequence[int]]): The list of shapes for each output in ``output_names``.
-              output_dtypes (List[Dtype]): The list of data types for each output in ``output_names``.
-              grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
-              threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
-              template (List[Tuple[str, Union[bool, int, Dtype]]], optional): Template arguments.
-                These will be added as template arguments to the kernel definition. Default: ``None``.
-              init_value (float, optional): Optional value to use to initialize all of the output arrays.
-                By default, output arrays are uninitialized. Default: ``None``.
-              verbose (bool, optional): Whether to print the full generated source code of the kernel
-                when it is run. Default: ``False``.
-              stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.
+        Args:
+          inputs (List[array]): The inputs passed to the Metal kernel.
+          output_shapes (List[Sequence[int]]): The list of shapes for each output in ``output_names``.
+          output_dtypes (List[Dtype]): The list of data types for each output in ``output_names``.
+          grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
+          threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
+          template (List[Tuple[str, Union[bool, int, Dtype]]], optional): Template arguments.
+            These will be added as template arguments to the kernel definition. Default: ``None``.
+          init_value (float, optional): Optional value to use to initialize all of the output arrays.
+            By default, output arrays are uninitialized. Default: ``None``.
+          verbose (bool, optional): Whether to print the full generated source code of the kernel
+            when it is run. Default: ``False``.
+          stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.
 
-            Returns:
-              List[array]: The list of output arrays.
-            )pbdoc");
+        Returns:
+          List[array]: The list of output arrays.)pbdoc");
      },
      "name"_a,
      "input_names"_a,
@@ -306,16 +305,16 @@ void init_fast(nb::module_& parent_module) {
         input_names (List[str]): The parameter names of the inputs in the
           function signature.
         output_names (List[str]): The parameter names of the outputs in the
-            function signature.
+          function signature.
         source (str): Source code. This is the body of a function in Metal,
-            the function signature will be automatically generated.
+          the function signature will be automatically generated.
         header (str): Header source code to include before the main function.
-            Useful for helper functions or includes that should live outside of
-            the main function body.
+          Useful for helper functions or includes that should live outside of
+          the main function body.
         ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
-            before the kernel runs. Default: ``True``.
+          before the kernel runs. Default: ``True``.
         atomic_outputs (bool): Whether to use atomic outputs in the function signature
-            e.g. ``device atomic<float>``. Default: ``False``.
+          e.g. ``device atomic<float>``. Default: ``False``.
 
         Returns:
           Callable ``metal_kernel``.
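A usage sketch for the `metal_kernel` binding whose docstrings are edited above. The exponential kernel body, the array size, and the dtype below are illustrative assumptions, not part of this patch; only the parameter names match the documented signature:

    import mlx.core as mx

    # Body of the Metal function; the full signature is generated from the
    # input/output names passed to `metal_kernel` below.
    source = """
        uint elem = thread_position_in_grid.x;
        T tmp = inp[elem];
        out[elem] = metal::exp(tmp);
    """

    kernel = mx.fast.metal_kernel(
        name="myexp",
        input_names=["inp"],
        output_names=["out"],
        source=source,
    )

    a = mx.random.normal(shape=(4096,))
    # Launch one thread per element of `a`.
    outputs = kernel(
        inputs=[a],
        template=[("T", mx.float32)],
        grid=(a.size, 1, 1),
        threadgroup=(256, 1, 1),
        output_shapes=[a.shape],
        output_dtypes=[a.dtype],
    )
    print(outputs[0])  # elementwise exp of `a`, computed by the custom kernel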