From 4f9f9ebb6f8da21e5aabf3f0cad1196f1cbe6536 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Wed, 25 Sep 2024 12:07:43 -0700
Subject: [PATCH] Faster Metal unary and binary for general case (#1431)

* faster unary and binary for general case

* update ternary + jit fix

* fix jit

* unary work per thread
---
 mlx/backend/metal/binary.cpp               | 29 +++++++-----
 mlx/backend/metal/jit_kernels.cpp          | 34 +++++++++-----
 mlx/backend/metal/kernels/binary.h         | 16 +++++--
 mlx/backend/metal/kernels/binary.metal     | 25 ++++++-----
 mlx/backend/metal/kernels/binary_two.h     | 20 ++++++---
 mlx/backend/metal/kernels/binary_two.metal | 25 ++++++-----
 mlx/backend/metal/kernels/ternary.h        | 24 +++++++---
 mlx/backend/metal/kernels/ternary.metal    |  1 +
 mlx/backend/metal/kernels/unary.h          | 17 +++++--
 mlx/backend/metal/kernels/unary.metal      |  9 ++--
 mlx/backend/metal/ternary.cpp              | 24 +++++-----
 mlx/backend/metal/unary.cpp                | 52 +++++++++++++++++-----
 12 files changed, 183 insertions(+), 93 deletions(-)

diff --git a/mlx/backend/metal/binary.cpp b/mlx/backend/metal/binary.cpp
index 59a661fc8..248fb526c 100644
--- a/mlx/backend/metal/binary.cpp
+++ b/mlx/backend/metal/binary.cpp
@@ -19,14 +19,13 @@
 
 namespace mlx::core {
 
-constexpr int MAX_BINARY_SPECIALIZED_DIMS = 3;
-
 std::string get_kernel_name(
     BinaryOpType bopt,
     const std::string& op,
     const array& a,
     bool use_2d,
-    int ndim) {
+    int ndim,
+    int work_per_thread) {
   std::ostringstream kname;
   switch (bopt) {
     case BinaryOpType::ScalarScalar:
@@ -43,14 +42,17 @@ std::string get_kernel_name(
       break;
     case BinaryOpType::General:
       kname << "g";
-      if (ndim <= MAX_BINARY_SPECIALIZED_DIMS) {
+      if (ndim <= 3) {
         kname << ndim;
       } else {
         kname << "n";
+        if (work_per_thread > 1) {
+          kname << work_per_thread;
+        }
       }
       break;
   }
-  kname << op << type_to_name(a);
+  kname << "_" << op << type_to_name(a);
   return kname.str();
 }
 
@@ -85,7 +87,11 @@ void binary_op_gpu_inplace(
   auto [shape, strides_a, strides_b, strides_out] = maybe_collapse();
 
   bool use_2d = out.data_size() > UINT32_MAX;
-  std::string kernel_name = get_kernel_name(bopt, op, a, use_2d, shape.size());
+  auto ndim = shape.size();
+  int work_per_thread =
+      (bopt == BinaryOpType::General && shape[ndim - 1] > 4) ? 4 : 1;
+  std::string kernel_name =
+      get_kernel_name(bopt, op, a, use_2d, shape.size(), work_per_thread);
   auto& d = metal::device(s.device);
 
   auto kernel = outputs.size() == 2
@@ -110,7 +116,11 @@ void binary_op_gpu_inplace(
   }
 
   if (bopt == BinaryOpType::General) {
-    auto ndim = shape.size();
+    // Launch up to 3D grid of threads
+    size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
+    size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
+    size_t rest = out.size() / (dim0 * dim1);
+
     if (ndim > 3) {
       compute_encoder->setBytes(shape.data(), ndim * sizeof(int), arg_idx++);
       compute_encoder->setBytes(
@@ -118,6 +128,7 @@ void binary_op_gpu_inplace(
       compute_encoder->setBytes(
           strides_b.data(), ndim * sizeof(size_t), arg_idx++);
       compute_encoder->setBytes(&ndim, sizeof(int), arg_idx++);
+      dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
     } else {
       // The shape is implicit in the grid for <= 3D
      compute_encoder->setBytes(
@@ -126,10 +137,6 @@ void binary_op_gpu_inplace(
           strides_b.data(), ndim * sizeof(size_t), arg_idx++);
     }
 
-    // Launch up to 3D grid of threads
-    size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
-    size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
-    size_t rest = out.size() / (dim0 * dim1);
     NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
     if (thread_group_size != 1024) {
       throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
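
Note on the naming above: general-case kernels now carry the work-per-thread factor and an underscore before the op, so a collapsed 2-D case resolves to a name like "g2_addfloat32" while a >3-D case with four outputs per thread resolves to "gn4_addfloat32" (the op and dtype strings here are placeholders). A minimal standalone sketch of that rule, which mirrors the logic in get_kernel_name rather than calling it:

    #include <iostream>
    #include <string>

    // Toy model of the general-case naming in get_kernel_name above:
    // "g<ndim>" when the collapsed shape is at most 3-D, otherwise "gn"
    // plus the work-per-thread factor when it is greater than 1, then
    // "_" followed by the op and type names.
    std::string general_kernel_name(
        int ndim,
        int work_per_thread,
        const std::string& op,
        const std::string& tname) {
      std::string name = "g";
      if (ndim <= 3) {
        name += std::to_string(ndim);
      } else {
        name += "n";
        if (work_per_thread > 1) {
          name += std::to_string(work_per_thread);
        }
      }
      return name + "_" + op + tname;
    }

    int main() {
      std::cout << general_kernel_name(2, 1, "add", "float32") << "\n"; // g2_addfloat32
      std::cout << general_kernel_name(5, 4, "add", "float32") << "\n"; // gn4_addfloat32
      return 0;
    }
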
diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp
index 74957f150..37e301142 100644
--- a/mlx/backend/metal/jit_kernels.cpp
+++ b/mlx/backend/metal/jit_kernels.cpp
@@ -42,18 +42,19 @@ MTL::ComputePipelineState* get_unary_kernel(
     const std::string& kernel_name,
     Dtype out_type,
     const std::string op) {
-  std::string lib_name = kernel_name.substr(1);
+  std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
   auto lib = d.get_library(lib_name);
   if (lib == nullptr) {
     std::ostringstream kernel_source;
-    auto u_def = get_template_definition(
-        "v" + lib_name, "unary_v", get_type_string(out_type), op);
-    auto u2_def = get_template_definition(
-        "v2" + lib_name, "unary_v2", get_type_string(out_type), op);
-    auto g_def = get_template_definition(
-        "g" + lib_name, "unary_g", get_type_string(out_type), op);
-    kernel_source << metal::utils() << metal::unary_ops() << metal::unary()
-                  << u_def << u2_def << g_def;
+    kernel_source << metal::utils() << metal::unary_ops() << metal::unary();
+    kernel_source << get_template_definition(
+        "v_" + lib_name, "unary_v", get_type_string(out_type), op);
+    kernel_source << get_template_definition(
+        "v2_" + lib_name, "unary_v2", get_type_string(out_type), op);
+    kernel_source << get_template_definition(
+        "g_" + lib_name, "unary_g", get_type_string(out_type), op);
+    kernel_source << get_template_definition(
+        "gn4_" + lib_name, "unary_g", get_type_string(out_type), op, 4);
     lib = d.get_library(lib_name, kernel_source.str());
   }
   return d.get_kernel(kernel_name, lib);
@@ -81,13 +82,20 @@ void add_binary_kernels(
   for (auto& [name, func] : kernel_types) {
     std::string template_def;
     template_def = get_template_definition(
-        name + lib_name,
+        name + "_" + lib_name,
         func,
         get_type_string(in_type),
         get_type_string(out_type),
         op);
     kernel_source << template_def;
   }
+  kernel_source << get_template_definition(
+      "gn4_" + lib_name,
+      "binary_g",
+      get_type_string(in_type),
+      get_type_string(out_type),
+      op,
+      4);
 }
 
 MTL::ComputePipelineState* get_binary_kernel(
@@ -96,7 +104,7 @@
     Dtype in_type,
     Dtype out_type,
     const std::string op) {
-  std::string lib_name = kernel_name.substr(2);
+  std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
   auto lib = d.get_library(lib_name);
   if (lib == nullptr) {
     std::ostringstream kernel_source;
@@ -113,7 +121,7 @@ MTL::ComputePipelineState* get_binary_two_kernel(
     Dtype in_type,
     Dtype out_type,
     const std::string op) {
-  std::string lib_name = kernel_name.substr(2);
+  std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
   auto lib = d.get_library(lib_name);
   if (lib == nullptr) {
     std::ostringstream kernel_source;
@@ -149,6 +157,8 @@ MTL::ComputePipelineState* get_ternary_kernel(
         name + "_" + lib_name, func, get_type_string(type), op);
     kernel_source << template_def;
   }
+  kernel_source << get_template_definition(
+      "gn4_" + lib_name, "ternary_g", get_type_string(type), op, 4);
   lib = d.get_library(lib_name, kernel_source.str());
   }
   return d.get_kernel(kernel_name, lib);
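
The substr changes above exist because the kernel-name prefix is no longer a fixed length ("v", "vv2", "gn4", ...): the JIT library name is now everything after the first underscore, so every variant of one op/type pair shares a single compiled library. A quick standalone illustration (the kernel names below are made-up examples following the new scheme):

    #include <cassert>
    #include <string>

    // The library name is the part after the first underscore, so "v_",
    // "v2_", "g_", and "gn4_" variants all map to the same JIT library.
    std::string lib_name_of(const std::string& kernel_name) {
      return kernel_name.substr(kernel_name.find("_") + 1);
    }

    int main() {
      assert(lib_name_of("v_absfloat32") == "absfloat32");
      assert(lib_name_of("gn4_absfloat32") == "absfloat32");
      assert(lib_name_of("vv2_addfloat16") == "addfloat16");
      return 0;
    }
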
diff --git a/mlx/backend/metal/kernels/binary.h b/mlx/backend/metal/kernels/binary.h
index c5a584b6d..d64488e9f 100644
--- a/mlx/backend/metal/kernels/binary.h
+++ b/mlx/backend/metal/kernels/binary.h
@@ -113,7 +113,7 @@ template <typename T, typename U, typename Op>
   c[out_idx] = Op()(a[a_idx], b[b_idx]);
 }
 
-template <typename T, typename U, typename Op>
+template <typename T, typename U, typename Op, int N = 1>
 [[kernel]] void binary_g(
     device const T* a,
     device const T* b,
@@ -124,8 +124,16 @@ template <typename T, typename U, typename Op>
     constant const int& ndim,
     uint3 index [[thread_position_in_grid]],
     uint3 grid_dim [[threads_per_grid]]) {
-  auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
+  auto idx = elem_to_loc_2_nd(
+      {N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
+  auto xshape = shape[ndim - 1];
   size_t out_idx =
-      index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
-  c[out_idx] = Op()(a[idx.x], b[idx.y]);
+      N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
+  auto a_xstride = a_strides[ndim - 1];
+  auto b_xstride = b_strides[ndim - 1];
+  for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
+    c[out_idx++] = Op()(a[idx.x], b[idx.y]);
+    idx.x += a_xstride;
+    idx.y += b_xstride;
+  }
 }
diff --git a/mlx/backend/metal/kernels/binary.metal b/mlx/backend/metal/kernels/binary.metal
index 5600de23e..5c437bd2a 100644
--- a/mlx/backend/metal/kernels/binary.metal
+++ b/mlx/backend/metal/kernels/binary.metal
@@ -9,18 +9,19 @@
 #include "mlx/backend/metal/kernels/binary_ops.h"
 #include "mlx/backend/metal/kernels/binary.h"
 
-#define instantiate_binary_all(op, tname, itype, otype)               \
-  instantiate_kernel("ss" #op #tname, binary_ss, itype, otype, op)    \
-  instantiate_kernel("sv" #op #tname, binary_sv, itype, otype, op)    \
-  instantiate_kernel("vs" #op #tname, binary_vs, itype, otype, op)    \
-  instantiate_kernel("vv" #op #tname, binary_vv, itype, otype, op)    \
-  instantiate_kernel("sv2" #op #tname, binary_sv2, itype, otype, op)  \
-  instantiate_kernel("vs2" #op #tname, binary_vs2, itype, otype, op)  \
-  instantiate_kernel("vv2" #op #tname, binary_vv2, itype, otype, op)  \
-  instantiate_kernel("gn" #op #tname, binary_g, itype, otype, op)     \
-  instantiate_kernel("g1" #op #tname, binary_g_nd1, itype, otype, op) \
-  instantiate_kernel("g2" #op #tname, binary_g_nd2, itype, otype, op) \
-  instantiate_kernel("g3" #op #tname, binary_g_nd3, itype, otype, op) \
+#define instantiate_binary_all(op, tname, itype, otype)                \
+  instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op)    \
+  instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op)    \
+  instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op)    \
+  instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op)    \
+  instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op)  \
+  instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op)  \
+  instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op)  \
+  instantiate_kernel("gn_" #op #tname, binary_g, itype, otype, op)     \
+  instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \
+  instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
+  instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \
+  instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op) \
 
 #define instantiate_binary_integer(op)                 \
   instantiate_binary_all(op, uint8, uint8_t, uint8_t)  \
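
Worked example of how the new gn4 kernels cover a ragged innermost axis: with an innermost length of 10 and N = 4 the host launches ceil(10 / 4) = 3 threads along x; the first two write four outputs each and the third writes the remaining two, because the (N * index.x + i) < xshape guard stops it early. A small standalone check of that arithmetic (a sketch, not the Metal source):

    #include <cassert>

    int main() {
      const int N = 4;       // work per thread
      const int xshape = 10; // innermost axis length
      const int dim0 = (xshape + N - 1) / N; // threads launched along x: 3
      int written = 0;
      for (int x = 0; x < dim0; ++x) {
        for (int i = 0; i < N && (N * x + i) < xshape; ++i) {
          ++written; // thread x writes output N * x + i
        }
      }
      assert(written == xshape); // every output is covered exactly once
      return 0;
    }
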
diff --git a/mlx/backend/metal/kernels/binary_two.h b/mlx/backend/metal/kernels/binary_two.h
index f40d81e86..a4a3130bf 100644
--- a/mlx/backend/metal/kernels/binary_two.h
+++ b/mlx/backend/metal/kernels/binary_two.h
@@ -143,7 +143,7 @@ template <typename T, typename U, typename Op>
   d[out_idx] = out[1];
 }
 
-template <typename T, typename U, typename Op>
+template <typename T, typename U, typename Op, int N = 1>
 [[kernel]] void binary_g(
     device const T* a,
     device const T* b,
@@ -155,10 +155,18 @@ template <typename T, typename U, typename Op>
     constant const int& ndim,
     uint3 index [[thread_position_in_grid]],
     uint3 grid_dim [[threads_per_grid]]) {
-  auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
+  auto idx = elem_to_loc_2_nd(
+      {N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
+  auto xshape = shape[ndim - 1];
   size_t out_idx =
-      index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
-  auto out = Op()(a[idx.x], b[idx.y]);
-  c[out_idx] = out[0];
-  d[out_idx] = out[1];
+      N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
+  auto a_xstride = a_strides[ndim - 1];
+  auto b_xstride = b_strides[ndim - 1];
+  for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
+    auto out = Op()(a[idx.x], b[idx.y]);
+    c[out_idx] = out[0];
+    d[out_idx++] = out[1];
+    idx.x += a_xstride;
+    idx.y += b_xstride;
+  }
 }
diff --git a/mlx/backend/metal/kernels/binary_two.metal b/mlx/backend/metal/kernels/binary_two.metal
index 8481776aa..f062439ec 100644
--- a/mlx/backend/metal/kernels/binary_two.metal
+++ b/mlx/backend/metal/kernels/binary_two.metal
@@ -7,18 +7,19 @@
 #include "mlx/backend/metal/kernels/binary_ops.h"
 #include "mlx/backend/metal/kernels/binary_two.h"
 
-#define instantiate_binary_all(op, tname, itype, otype)               \
-  instantiate_kernel("ss" #op #tname, binary_ss, itype, otype, op)    \
-  instantiate_kernel("sv" #op #tname, binary_sv, itype, otype, op)    \
-  instantiate_kernel("vs" #op #tname, binary_vs, itype, otype, op)    \
-  instantiate_kernel("vv" #op #tname, binary_vv, itype, otype, op)    \
-  instantiate_kernel("sv2" #op #tname, binary_sv2, itype, otype, op)  \
-  instantiate_kernel("vs2" #op #tname, binary_vs2, itype, otype, op)  \
-  instantiate_kernel("vv2" #op #tname, binary_vv2, itype, otype, op)  \
-  instantiate_kernel("gn" #op #tname, binary_g, itype, otype, op)     \
-  instantiate_kernel("g1" #op #tname, binary_g_nd1, itype, otype, op) \
-  instantiate_kernel("g2" #op #tname, binary_g_nd2, itype, otype, op) \
-  instantiate_kernel("g3" #op #tname, binary_g_nd3, itype, otype, op) \
+#define instantiate_binary_all(op, tname, itype, otype)                \
+  instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op)    \
+  instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op)    \
+  instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op)    \
+  instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op)    \
+  instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op)  \
+  instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op)  \
+  instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op)  \
+  instantiate_kernel("gn_" #op #tname, binary_g, itype, otype, op)     \
+  instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \
+  instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
+  instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \
+  instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op) \
 
 #define instantiate_binary_float(op)            \
   instantiate_binary_all(op, float16, half, half) \
diff --git a/mlx/backend/metal/kernels/ternary.h b/mlx/backend/metal/kernels/ternary.h
index 7cc062500..2bd1242c9 100644
--- a/mlx/backend/metal/kernels/ternary.h
+++ b/mlx/backend/metal/kernels/ternary.h
@@ -75,7 +75,7 @@ template <typename T, typename Op>
   d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
 }
 
-template <typename T, typename Op>
+template <typename T, typename Op, int N = 1>
 [[kernel]] void ternary_g(
     device const bool* a,
     device const T* b,
@@ -88,9 +88,23 @@ template <typename T, typename Op>
     constant const int& ndim,
     uint3 index [[thread_position_in_grid]],
     uint3 grid_dim [[threads_per_grid]]) {
-  auto idx =
-      elem_to_loc_3_nd(index, shape, a_strides, b_strides, c_strides, ndim);
+  auto idx = elem_to_loc_3_nd(
+      {N * index.x, index.y, index.z},
+      shape,
+      a_strides,
+      b_strides,
+      c_strides,
+      ndim);
+  auto xshape = shape[ndim - 1];
   size_t out_idx =
-      index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
-  d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
+      N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
+  auto a_xstride = a_strides[ndim - 1];
+  auto b_xstride = b_strides[ndim - 1];
+  auto c_xstride = c_strides[ndim - 1];
+  for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
+    d[out_idx++] = Op()(a[idx.x], b[idx.y], c[idx.z]);
+    idx.x += a_xstride;
+    idx.y += b_xstride;
+    idx.z += c_xstride;
+  }
 }
diff --git a/mlx/backend/metal/kernels/ternary.metal b/mlx/backend/metal/kernels/ternary.metal
index 47894594c..79e427775 100644
--- a/mlx/backend/metal/kernels/ternary.metal
+++ b/mlx/backend/metal/kernels/ternary.metal
@@ -13,6 +13,7 @@
   instantiate_kernel("v_" #op #tname, ternary_v, type, op)      \
   instantiate_kernel("v2_" #op #tname, ternary_v2, type, op)    \
   instantiate_kernel("g_" #op #tname, ternary_g, type, op)      \
+  instantiate_kernel("gn4_" #op #tname, ternary_g, type, op, 4) \
   instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op) \
   instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op) \
   instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op) \
diff --git a/mlx/backend/metal/kernels/unary.h b/mlx/backend/metal/kernels/unary.h
index ecdf34e1d..8d404ae25 100644
--- a/mlx/backend/metal/kernels/unary.h
+++ b/mlx/backend/metal/kernels/unary.h
@@ -18,14 +18,23 @@ template <typename T, typename Op>
   out[offset] = Op()(in[offset]);
 }
 
-template <typename T, typename Op>
+template <typename T, typename Op, int N = 1>
 [[kernel]] void unary_g(
     device const T* in,
     device T* out,
     constant const int* in_shape,
     constant const size_t* in_strides,
     device const int& ndim,
-    uint index [[thread_position_in_grid]]) {
-  auto idx = elem_to_loc(index, in_shape, in_strides, ndim);
-  out[index] = Op()(in[idx]);
+    uint3 index [[thread_position_in_grid]],
+    uint3 grid_dim [[threads_per_grid]]) {
+  auto idx =
+      elem_to_loc({N * index.x, index.y, index.z}, in_shape, in_strides, ndim);
+  auto xshape = in_shape[ndim - 1];
+  auto xstride = in_strides[ndim - 1];
+  size_t out_idx =
+      N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
+  for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
+    out[out_idx++] = Op()(in[idx]);
+    idx += xstride;
+  }
 }
diff --git a/mlx/backend/metal/kernels/unary.metal b/mlx/backend/metal/kernels/unary.metal
index 01eaab512..0c1b5d9e1 100644
--- a/mlx/backend/metal/kernels/unary.metal
+++ b/mlx/backend/metal/kernels/unary.metal
@@ -5,10 +5,11 @@
 #include "mlx/backend/metal/kernels/unary_ops.h"
 #include "mlx/backend/metal/kernels/unary.h"
 
-#define instantiate_unary_all(op, tname, type)              \
-  instantiate_kernel("v" #op #tname, unary_v, type, op)     \
-  instantiate_kernel("v2" #op #tname, unary_v2, type, op)   \
-  instantiate_kernel("g" #op #tname, unary_g, type, op)
+#define instantiate_unary_all(op, tname, type)                 \
+  instantiate_kernel("v_" #op #tname, unary_v, type, op)       \
+  instantiate_kernel("v2_" #op #tname, unary_v2, type, op)     \
+  instantiate_kernel("gn4_" #op #tname, unary_g, type, op, 4)  \
+  instantiate_kernel("g_" #op #tname, unary_g, type, op)
 
 #define instantiate_unary_float(op)        \
   instantiate_unary_all(op, float16, half) \
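
The common pattern in the reworked general kernels above is to resolve the strided input offset once per block of N outputs along the innermost axis and then step by the innermost stride, instead of recomputing the full multi-dimensional index for every element. A host-side C++ model of that pattern (illustrative only; elem_to_loc here is a stand-in for the real helper and a plain copy stands in for the op):

    #include <cstddef>
    #include <vector>

    // Stand-in for the elem_to_loc helper: maps a linear element index to a
    // memory offset in an arbitrarily strided input.
    size_t elem_to_loc(
        size_t elem,
        const std::vector<int>& shape,
        const std::vector<size_t>& strides) {
      size_t loc = 0;
      for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
        loc += (elem % shape[i]) * strides[i];
        elem /= shape[i];
      }
      return loc;
    }

    // Host-side model of the general kernels: one full index computation per
    // block of N outputs on the innermost axis, then cheap stride increments
    // for the rest of the block. Output is contiguous over `shape`.
    void copy_general(
        const float* in,
        float* out,
        const std::vector<int>& shape,
        const std::vector<size_t>& strides,
        int N = 4) {
      int xshape = shape.back();
      size_t xstride = strides.back();
      size_t rows = 1;
      for (size_t d = 0; d + 1 < shape.size(); ++d) {
        rows *= shape[d];
      }
      for (size_t row = 0; row < rows; ++row) {
        for (int x = 0; x < xshape; x += N) { // one simulated thread per block
          size_t out_idx = row * static_cast<size_t>(xshape) + x;
          size_t idx = elem_to_loc(out_idx, shape, strides);
          for (int i = 0; i < N && x + i < xshape; ++i) {
            out[out_idx + i] = in[idx];
            idx += xstride;
          }
        }
      }
    }

On the device, the two outer loops are replaced by the 3-D grid set up in binary.cpp and ternary.cpp above, with the innermost grid dimension shrunk to ceil(xshape / N) threads.
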
diff --git a/mlx/backend/metal/ternary.cpp b/mlx/backend/metal/ternary.cpp
index 3c109018b..c70b5e969 100644
--- a/mlx/backend/metal/ternary.cpp
+++ b/mlx/backend/metal/ternary.cpp
@@ -8,8 +8,6 @@
 
 namespace mlx::core {
 
-constexpr int MAX_TERNARY_SPECIALIZED_DIMS = 3;
-
 void ternary_op_gpu_inplace(
     const std::vector<array>& inputs,
     array& out,
@@ -43,13 +41,18 @@ void ternary_op_gpu_inplace(
   auto [shape, strides_a, strides_b, strides_c, strides_out] = maybe_collapse();
 
   bool use_2d = out.data_size() > UINT_MAX;
+  auto ndim = shape.size();
+  int work_per_thread =
+      (topt == TernaryOpType::General && shape[ndim - 1] > 4) ? 4 : 1;
   std::string kernel_name;
   {
     std::ostringstream kname;
     if (topt == TernaryOpType::General) {
       kname << "g";
-      if (shape.size() <= MAX_TERNARY_SPECIALIZED_DIMS) {
+      if (shape.size() <= 3) {
         kname << shape.size();
+      } else if (work_per_thread > 1) {
+        kname << "n" << work_per_thread;
       }
     } else if (use_2d) {
       kname << "v2";
@@ -75,16 +78,19 @@ void ternary_op_gpu_inplace(
   compute_encoder.set_output_array(out, 3);
 
   if (topt == TernaryOpType::General) {
-    auto ndim = shape.size();
+    // Launch up to 3D grid of threads
+    size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
+    size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
+    size_t rest = out.size() / (dim0 * dim1);
+
     if (ndim > 3) {
       compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
       compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5);
       compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6);
       compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 7);
-      if (ndim > MAX_TERNARY_SPECIALIZED_DIMS) {
-        compute_encoder->setBytes(&ndim, sizeof(int), 8);
-      }
+      compute_encoder->setBytes(&ndim, sizeof(int), 8);
+      dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
     } else {
       // The shape is implicit in the grid for <= 3D
       compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
@@ -92,10 +98,6 @@ void ternary_op_gpu_inplace(
       compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 6);
     }
 
-    // Launch up to 3D grid of threads
-    size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
-    size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
-    size_t rest = out.size() / (dim0 * dim1);
     NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
     if (thread_group_size != 1024) {
       throw std::runtime_error("[Metal::ternary] Must use 1024 sized block");
diff --git a/mlx/backend/metal/unary.cpp b/mlx/backend/metal/unary.cpp
index 17ff1f7b3..666739d3a 100644
--- a/mlx/backend/metal/unary.cpp
+++ b/mlx/backend/metal/unary.cpp
@@ -1,5 +1,6 @@
 // Copyright © 2024 Apple Inc.
 
+#include "mlx/backend/common/utils.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/kernels.h"
 #include "mlx/backend/metal/utils.h"
@@ -25,33 +26,60 @@ void unary_op_gpu_inplace(
 
   auto& d = metal::device(s.device);
 
+  auto maybe_collapse = [contig, &in, &out]() {
+    if (!contig) {
+      auto [shape, strides] = collapse_contiguous_dims(
+          {in, out},
+          /* size_cap = */ INT32_MAX);
+      return std::make_pair(shape, strides[0]);
+    } else {
+      return std::make_pair(std::vector<int>{}, std::vector<size_t>{});
+    }
+  };
+  auto [shape, strides] = maybe_collapse();
+  int ndim = shape.size();
+  int work_per_thread = (!contig && shape[ndim - 1] > 4) ? 4 : 1;
   size_t nthreads = contig ? in.data_size() : in.size();
   bool use_2d = nthreads > UINT32_MAX;
-  std::string kernel_name =
-      (contig ? (use_2d ? "v2" : "v") : "g") + op + type_to_name(out);
+  std::string kernel_name;
+  if (contig) {
+    kernel_name = (use_2d ? "v2" : "v");
+  } else {
+    kernel_name = (work_per_thread == 4 ? "gn4" : "g");
+  }
+  kernel_name += "_" + op + type_to_name(out);
   auto kernel = get_unary_kernel(d, kernel_name, out.dtype(), op);
 
   MTL::Size grid_dims = use_2d ? get_2d_grid_dims(in.shape(), in.strides())
                                : MTL::Size(nthreads, 1, 1);
   NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
-  if (thread_group_size > nthreads) {
-    thread_group_size = nthreads;
-  }
-  MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-
   auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder->setComputePipelineState(kernel);
   compute_encoder.set_input_array(
       in.data_shared_ptr() == nullptr ? out : in, 0);
   compute_encoder.set_output_array(out, 1);
   if (!contig) {
-    compute_encoder->setBytes(in.shape().data(), in.ndim() * sizeof(int), 2);
-    compute_encoder->setBytes(
-        in.strides().data(), in.ndim() * sizeof(size_t), 3);
-    int ndim = in.ndim();
+    // Launch up to 3D grid of threads
+    size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
+    size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
+    size_t rest = out.size() / (dim0 * dim1);
+    compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 2);
+    compute_encoder->setBytes(strides.data(), ndim * sizeof(size_t), 3);
     compute_encoder->setBytes(&ndim, sizeof(int), 4);
+    if (thread_group_size != 1024) {
+      throw std::runtime_error("[Metal::unary] Must use 1024 sized block");
+    }
+    dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
+    auto group_dims = get_block_dims(dim0, dim1, rest);
+    MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
+  } else {
+    if (thread_group_size > nthreads) {
+      thread_group_size = nthreads;
+    }
+    MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
-  compute_encoder.dispatchThreads(grid_dims, group_dims);
 }
 
 void unary_op_gpu(
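
On the host side above, the four-per-thread variant is only chosen for non-contiguous inputs whose collapsed innermost axis is longer than 4; contiguous inputs keep the v/v2 kernels. A compact sketch of that selection (standalone and illustrative, not the library API):

    #include <string>

    // Mirrors the kernel-prefix selection in unary_op_gpu_inplace above:
    // contiguous inputs use the vector kernels, and the general path only
    // picks the four-per-thread variant when the collapsed innermost axis
    // is longer than 4.
    std::string unary_kernel_prefix(bool contig, bool use_2d, int innermost_dim) {
      if (contig) {
        return use_2d ? "v2" : "v";
      }
      int work_per_thread = innermost_dim > 4 ? 4 : 1;
      return work_per_thread == 4 ? "gn4" : "g";
    }
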