From 825124af8ffd32d0f2f7d8f8eca83c8c3eb510a7 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Mon, 5 May 2025 06:15:04 -0700
Subject: [PATCH] fix bw for elementwise ops (#2151)

* fix bw for elementwise ops

* add compile

* fix

* fix

* fix

* fix
---
 mlx/backend/metal/binary.cpp           | 15 ++++--
 mlx/backend/metal/compiled.cpp         | 38 +++++++++----
 mlx/backend/metal/copy.cpp             | 27 +++++++++---
 mlx/backend/metal/kernels/binary.h     | 51 ++++++++++++------
 mlx/backend/metal/kernels/binary_two.h | 75 ++++++++++++++++----------
 mlx/backend/metal/kernels/copy.h       | 34 ++++++++----
 mlx/backend/metal/kernels/ternary.h    | 17 ++++--
 mlx/backend/metal/kernels/unary.h      | 17 ++++--
 mlx/backend/metal/kernels/utils.h      |  8 +++
 mlx/backend/metal/ternary.cpp          | 14 +++--
 mlx/backend/metal/unary.cpp            | 17 ++++--
 mlx/backend/metal/utils.h              |  8 +++
 12 files changed, 232 insertions(+), 89 deletions(-)

diff --git a/mlx/backend/metal/binary.cpp b/mlx/backend/metal/binary.cpp
index f80f8c3e4..c3c67e4d5 100644
--- a/mlx/backend/metal/binary.cpp
+++ b/mlx/backend/metal/binary.cpp
@@ -90,7 +90,7 @@ void binary_op_gpu_inplace(
     work_per_thread = large ? 4 : 2;
   } else {
     large = out.data_size() > UINT32_MAX;
-    work_per_thread = 1;
+    work_per_thread = get_work_per_thread(a.dtype());
   }
   std::string kernel_name =
       get_kernel_name(bopt, op, a, large, shape.size(), work_per_thread);
@@ -137,13 +137,20 @@ void binary_op_gpu_inplace(
     compute_encoder.dispatch_threads(grid_dims, group_dims);
   } else {
     // Launch a 1D or 2D grid of threads
-    size_t nthreads = out.data_size();
+    size_t nthreads = ceildiv(out.data_size(), work_per_thread);
     if (thread_group_size > nthreads) {
       thread_group_size = nthreads;
     }
+
     MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-    MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
-                                : MTL::Size(nthreads, 1, 1);
+    MTL::Size grid_dims;
+    if (large) {
+      compute_encoder.set_bytes<int64_t>(out.data_size(), arg_idx++);
+      grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
+    } else {
+      compute_encoder.set_bytes<uint32_t>(out.data_size(), arg_idx++);
+      grid_dims = MTL::Size(nthreads, 1, 1);
+    }
     compute_encoder.dispatch_threads(grid_dims, group_dims);
   }
 }
diff --git a/mlx/backend/metal/compiled.cpp b/mlx/backend/metal/compiled.cpp
index 154273233..db20f938c 100644
--- a/mlx/backend/metal/compiled.cpp
+++ b/mlx/backend/metal/compiled.cpp
@@ -64,6 +64,7 @@ inline void build_kernel(
         cnt++);
   }
 
+  std::string idx_type = use_big_index ? "int64_t" : "uint";
   if (add_indices) {
     os += fmt::format(
         "    constant const int64_t* in_strides [[buffer({0})]],\n", cnt++);
@@ -83,6 +84,9 @@ inline void build_kernel(
         "    constant const int64_t* output_strides [[buffer({0})]],\n",
         cnt++);
     os += fmt::format(
         "    constant const int* output_shape [[buffer({0})]],\n", cnt++);
+  } else {
+    os += fmt::format(
+        "    constant const {0}& size [[buffer({1})]],\n", idx_type, cnt++);
   }
   if (dynamic_dims) {
     os += fmt::format("    constant const int& ndim [[buffer({0})]],\n", cnt++);
@@ -92,13 +96,14 @@ inline void build_kernel(
   os += "    uint3 pos [[thread_position_in_grid]],\n";
   os += "    uint3 grid [[threads_per_grid]]) {\n";
 
-  std::string idx_type = use_big_index ? "int64_t" : "uint";
+  os += fmt::format("  constexpr int N_ = {0};\n", work_per_thread);
   if (contiguous && use_big_index) {
     // This is only used for contiguous kernels which don't have
     // a third grid dimension
-    os += "  int64_t index = pos.x + grid.x * int64_t(pos.y);\n";
+    os += "  int64_t index = N_ * (pos.x + grid.x * int64_t(pos.y));\n";
+  } else if (contiguous) {
+    os += "  uint index = N_ * pos.x;\n";
   } else if (work_per_thread > 1) {
-    os += fmt::format("  constexpr int N_ = {0};\n", work_per_thread);
     os += fmt::format(
         "  int xshape = output_shape[{0}];\n",
         dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1));
@@ -110,6 +115,9 @@ inline void build_kernel(
         "  {0} index = pos.x + grid.x * (pos.y + {0}(grid.y) * pos.z);\n",
         idx_type);
   }
+  if (work_per_thread > 1 && contiguous) {
+    os += "  for (int i = 0; i < N_ && index < size; ++i) {\n";
+  }
 
   // Read constant / contiguous inputs in tmps
   std::vector<array> nc_inputs;
@@ -193,7 +201,7 @@ inline void build_kernel(
   }
 
   // Open per-thread loop
-  if (work_per_thread > 1) {
+  if (work_per_thread > 1 && !contiguous) {
     os +=
         "  for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n";
   }
@@ -272,6 +280,7 @@ void Compiled::eval_gpu(
   auto& s = stream();
   auto& d = metal::device(s.device);
   auto lib = d.get_library(kernel_lib_, [&]() {
+    int work_per_thread = get_work_per_thread(outputs_[0].dtype());
     std::string kernel = metal::utils();
     concatenate(
         kernel, metal::unary_ops(), metal::binary_ops(), metal::ternary_ops());
@@ -284,7 +293,9 @@ void Compiled::eval_gpu(
         constant_ids_,
         /* contiguous = */ true,
         /* ndim = */ 0,
-        /* dynamic_dims = */ false);
+        /* dynamic_dims = */ false,
+        /* use_big_index = */ false,
+        /* work_per_thread = */ work_per_thread);
     build_kernel(
         kernel,
         kernel_lib_ + "_contiguous_large",
@@ -295,7 +306,8 @@ void Compiled::eval_gpu(
         /* contiguous = */ true,
         /* ndim = */ 0,
         /* dynamic_dims = */ false,
-        /* use_big_index = */ true);
+        /* use_big_index = */ true,
+        /* work_per_thread = */ work_per_thread);
     for (int i = 1; i < 8; i++) {
       build_kernel(
           kernel,
@@ -468,6 +480,13 @@ void Compiled::eval_gpu(
   if (!contiguous) {
     compute_encoder.set_vector_bytes(strides[0], cnt++);
     compute_encoder.set_vector_bytes(shape, cnt++);
+  } else {
+    auto size = outputs[0].data_size();
+    if (large) {
+      compute_encoder.set_bytes<int64_t>(size, cnt++);
+    } else {
+      compute_encoder.set_bytes<uint32_t>(size, cnt++);
+    }
   }
 
   // Put the number of dims in if it is dynamic
@@ -477,12 +496,13 @@ void Compiled::eval_gpu(
   // Launch the kernel
   if (contiguous) {
-    size_t nthreads = outputs[0].data_size();
+    int work_per_thread = get_work_per_thread(outputs[0].dtype());
+    size_t nthreads = ceildiv(outputs[0].data_size(), work_per_thread);
     MTL::Size group_dims(
         std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
-    MTL::Size grid_dims = large
-        ? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
+    MTL::Size grid_dims = large
+        ? get_2d_grid_dims(
+              outputs[0].shape(), outputs[0].strides(), work_per_thread)
         : MTL::Size(nthreads, 1, 1);
     compute_encoder.dispatch_threads(grid_dims, group_dims);
   } else {
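
Note on the generated source: with these changes, build_kernel emits an
elementwise loop for contiguous compiled kernels. For a small (32-bit indexed)
float32 graph with a single input, the assembled kernel looks roughly like the
sketch below. The kernel name, input name, and op are illustrative, and the
loop body and trailing index increment are paraphrased from parts of
build_kernel outside these hunks, so treat this as a sketch of the emitted
string, not the string itself:

  [[kernel]] void kernel_contiguous(
      device const float* in0 [[buffer(0)]],
      device float* out0 [[buffer(1)]],
      constant const uint& size [[buffer(2)]],
      uint3 pos [[thread_position_in_grid]],
      uint3 grid [[threads_per_grid]]) {
    constexpr int N_ = 2;     // work_per_thread for a 4-byte dtype
    uint index = N_ * pos.x;  // contiguous, small-index branch
    for (int i = 0; i < N_ && index < size; ++i) {
      float tmp_in0 = in0[index];    // "read inputs in tmps" step
      out0[index] = Exp()(tmp_in0);  // compiled op body (illustrative)
      index++;
    }
  }

The size argument is the one added in the else branch above; idx_type switches
it to int64_t when use_big_index is set.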
"int64_t" : "uint"; + os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread); if (contiguous && use_big_index) { // This is only used for contiguous kernels which don't have // a third grid dimension - os += " int64_t index = pos.x + grid.x * int64_t(pos.y);\n"; + os += " int64_t index = N_ * (pos.x + grid.x * int64_t(pos.y));\n"; + } else if (contiguous) { + os += " uint index = N_ * pos.x;\n"; } else if (work_per_thread > 1) { - os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread); os += fmt::format( " int xshape = output_shape[{0}];\n", dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1)); @@ -110,6 +115,9 @@ inline void build_kernel( " {0} index = pos.x + grid.x * (pos.y + {0}(grid.y) * pos.z);\n", idx_type); } + if (work_per_thread > 1 && contiguous) { + os += " for (int i = 0; i < N_ && index < size; ++i) {\n"; + } // Read constant / contiguous inputs in tmps std::vector nc_inputs; @@ -193,7 +201,7 @@ inline void build_kernel( } // Open per-thread loop - if (work_per_thread > 1) { + if (work_per_thread > 1 && !contiguous) { os += " for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n"; } @@ -272,6 +280,7 @@ void Compiled::eval_gpu( auto& s = stream(); auto& d = metal::device(s.device); auto lib = d.get_library(kernel_lib_, [&]() { + int work_per_thread = get_work_per_thread(outputs_[0].dtype()); std::string kernel = metal::utils(); concatenate( kernel, metal::unary_ops(), metal::binary_ops(), metal::ternary_ops()); @@ -284,7 +293,9 @@ void Compiled::eval_gpu( constant_ids_, /* contiguous = */ true, /* ndim = */ 0, - /* dynamic_dims = */ false); + /* dynamic_dims = */ false, + /* use_big_index = */ false, + /* work_per_thread = */ work_per_thread); build_kernel( kernel, kernel_lib_ + "_contiguous_large", @@ -295,7 +306,8 @@ void Compiled::eval_gpu( /* contiguous = */ true, /* ndim = */ 0, /* dynamic_dims = */ false, - /* use_big_index = */ true); + /* use_big_index = */ true, + /* work_per_thread = */ work_per_thread); for (int i = 1; i < 8; i++) { build_kernel( kernel, @@ -468,6 +480,13 @@ void Compiled::eval_gpu( if (!contiguous) { compute_encoder.set_vector_bytes(strides[0], cnt++); compute_encoder.set_vector_bytes(shape, cnt++); + } else { + auto size = outputs[0].data_size(); + if (large) { + compute_encoder.set_bytes(size, cnt++); + } else { + compute_encoder.set_bytes(size, cnt++); + } } // Put the number of dims in if it is dynamic @@ -477,12 +496,13 @@ void Compiled::eval_gpu( // Launch the kernel if (contiguous) { - size_t nthreads = outputs[0].data_size(); + int work_per_thread = get_work_per_thread(outputs[0].dtype()); + size_t nthreads = ceildiv(outputs[0].data_size(), work_per_thread); MTL::Size group_dims( std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1); - MTL::Size grid_dims = large - ? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides()) + ? get_2d_grid_dims( + outputs[0].shape(), outputs[0].strides(), work_per_thread) : MTL::Size(nthreads, 1, 1); compute_encoder.dispatch_threads(grid_dims, group_dims); } else { diff --git a/mlx/backend/metal/copy.cpp b/mlx/backend/metal/copy.cpp index 3399201de..ee004359f 100644 --- a/mlx/backend/metal/copy.cpp +++ b/mlx/backend/metal/copy.cpp @@ -104,6 +104,8 @@ void copy_gpu_inplace( "[Copy::eval_gpu] Dynamic output offset requires GeneralGeneral copy"); } } + } else { + work_per_thread = get_work_per_thread(in.dtype()); } concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out)); auto kernel = dynamic ? 
diff --git a/mlx/backend/metal/kernels/binary.h b/mlx/backend/metal/kernels/binary.h
index 91a02c818..ffc33ad82 100644
--- a/mlx/backend/metal/kernels/binary.h
+++ b/mlx/backend/metal/kernels/binary.h
@@ -9,64 +9,85 @@ template <typename T, typename U, typename Op>
   c[index] = Op()(a[0], b[0]);
 }
 
-template <typename T, typename U, typename Op>
+template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
 [[kernel]] void binary_sv(
     device const T* a,
     device const T* b,
     device U* c,
+    constant uint& size,
     uint index [[thread_position_in_grid]]) {
-  c[index] = Op()(a[0], b[index]);
+  index *= N;
+  for (int i = 0; i < N && (index + i) < size; ++i) {
+    c[index + i] = Op()(a[0], b[index + i]);
+  }
 }
 
-template <typename T, typename U, typename Op>
+template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
 [[kernel]] void binary_vs(
     device const T* a,
     device const T* b,
     device U* c,
+    constant uint& size,
     uint index [[thread_position_in_grid]]) {
-  c[index] = Op()(a[index], b[0]);
+  index *= N;
+  for (int i = 0; i < N && (index + i) < size; ++i) {
+    c[index + i] = Op()(a[index + i], b[0]);
+  }
 }
 
-template <typename T, typename U, typename Op>
+template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
 [[kernel]] void binary_vv(
     device const T* a,
     device const T* b,
     device U* c,
+    constant uint& size,
     uint index [[thread_position_in_grid]]) {
-  c[index] = Op()(a[index], b[index]);
+  index *= N;
+  for (int i = 0; i < N && (index + i) < size; ++i) {
+    c[index + i] = Op()(a[index + i], b[index + i]);
+  }
 }
 
-template <typename T, typename U, typename Op>
+template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
 [[kernel]] void binary_sv2(
     device const T* a,
     device const T* b,
     device U* c,
+    constant int64_t& size,
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
-  int64_t offset = index.x + grid_dim.x * int64_t(index.y);
-  c[offset] = Op()(a[0], b[offset]);
+  int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
+  for (int i = 0; i < N && (offset + i) < size; ++i) {
+    c[offset + i] = Op()(a[0], b[offset + i]);
+  }
 }
 
-template <typename T, typename U, typename Op>
+template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
 [[kernel]] void binary_vs2(
     device const T* a,
     device const T* b,
     device U* c,
+    constant int64_t& size,
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
-  int64_t offset = index.x + grid_dim.x * int64_t(index.y);
-  c[offset] = Op()(a[offset], b[0]);
+  int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
+  for (int i = 0; i < N && (offset + i) < size; ++i) {
+    c[offset + i] = Op()(a[offset + i], b[0]);
+  }
 }
 
-template <typename T, typename U, typename Op>
+template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
 [[kernel]] void binary_vv2(
     device const T* a,
     device const T* b,
     device U* c,
+    constant int64_t& size,
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
-  int64_t offset = index.x + grid_dim.x * int64_t(index.y);
-  c[offset] = Op()(a[offset], b[offset]);
+  int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
+  for (int i = 0; i < N && (offset + i) < size; ++i) {
+    c[offset + i] = Op()(a[offset + i], b[offset + i]);
+  }
 }
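
The (index + i) < size guard in these loops is what keeps the batched kernels
correct when data_size is not a multiple of N. A scalar C++ model of the same
indexing, useful for checking that every element is written exactly once (a
standalone sketch; none of these names are MLX API):

  #include <cassert>
  #include <vector>

  int main() {
    constexpr int N = 4;       // e.g. WorkPerThread<T>::n for a 2-byte type
    unsigned size = 10;        // deliberately not a multiple of N
    unsigned nthreads = (size + N - 1) / N;  // host-side ceildiv
    std::vector<int> writes(size, 0);
    for (unsigned t = 0; t < nthreads; ++t) {
      unsigned index = t * N;  // the kernel does index *= N
      for (int i = 0; i < N && (index + i) < size; ++i) {
        writes[index + i]++;   // stands in for c[index + i] = Op()(...)
      }
    }
    for (int w : writes) {
      assert(w == 1);          // full coverage, no duplicates, no overrun
    }
    return 0;
  }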
diff --git a/mlx/backend/metal/kernels/binary_two.h b/mlx/backend/metal/kernels/binary_two.h
index 8f6b3392d..e261d33c4 100644
--- a/mlx/backend/metal/kernels/binary_two.h
+++ b/mlx/backend/metal/kernels/binary_two.h
@@ -12,82 +12,103 @@ template <typename T, typename U, typename Op>
   d[index] = out[1];
 }
 
-template <typename T, typename U, typename Op>
+template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
 [[kernel]] void binary_sv(
     device const T* a,
     device const T* b,
     device U* c,
     device U* d,
+    constant uint& size,
     uint index [[thread_position_in_grid]]) {
-  auto out = Op()(a[0], b[index]);
-  c[index] = out[0];
-  d[index] = out[1];
+  index *= N;
+  for (int i = 0; i < N && (index + i) < size; ++i) {
+    auto out = Op()(a[0], b[index + i]);
+    c[index + i] = out[0];
+    d[index + i] = out[1];
+  }
 }
 
-template <typename T, typename U, typename Op>
+template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
 [[kernel]] void binary_vs(
     device const T* a,
     device const T* b,
     device U* c,
     device U* d,
+    constant uint& size,
     uint index [[thread_position_in_grid]]) {
-  auto out = Op()(a[index], b[0]);
-  c[index] = out[0];
-  d[index] = out[1];
+  index *= N;
+  for (int i = 0; i < N && (index + i) < size; ++i) {
+    auto out = Op()(a[index + i], b[0]);
+    c[index + i] = out[0];
+    d[index + i] = out[1];
+  }
 }
 
-template <typename T, typename U, typename Op>
+template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
 [[kernel]] void binary_vv(
     device const T* a,
     device const T* b,
     device U* c,
     device U* d,
+    constant uint& size,
     uint index [[thread_position_in_grid]]) {
-  auto out = Op()(a[index], b[index]);
-  c[index] = out[0];
-  d[index] = out[1];
+  index *= N;
+  for (int i = 0; i < N && (index + i) < size; ++i) {
+    auto out = Op()(a[index + i], b[index + i]);
+    c[index + i] = out[0];
+    d[index + i] = out[1];
+  }
 }
 
-template <typename T, typename U, typename Op>
+template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
 [[kernel]] void binary_sv2(
     device const T* a,
     device const T* b,
     device U* c,
     device U* d,
+    constant int64_t& size,
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
-  auto offset = index.x + grid_dim.x * int64_t(index.y);
-  auto out = Op()(a[0], b[offset]);
-  c[offset] = out[0];
-  d[offset] = out[1];
+  auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
+  for (int i = 0; i < N && (offset + i) < size; ++i) {
+    auto out = Op()(a[0], b[offset + i]);
+    c[offset + i] = out[0];
+    d[offset + i] = out[1];
+  }
 }
 
-template <typename T, typename U, typename Op>
+template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
 [[kernel]] void binary_vs2(
     device const T* a,
     device const T* b,
     device U* c,
     device U* d,
+    constant int64_t& size,
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
-  auto offset = index.x + grid_dim.x * int64_t(index.y);
-  auto out = Op()(a[offset], b[0]);
-  c[offset] = out[0];
-  d[offset] = out[1];
+  auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
+  for (int i = 0; i < N && (offset + i) < size; ++i) {
+    auto out = Op()(a[offset + i], b[0]);
+    c[offset + i] = out[0];
+    d[offset + i] = out[1];
+  }
 }
 
-template <typename T, typename U, typename Op>
+template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
 [[kernel]] void binary_vv2(
     device const T* a,
     device const T* b,
     device U* c,
     device U* d,
+    constant int64_t& size,
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
-  auto offset = index.x + grid_dim.x * int64_t(index.y);
-  auto out = Op()(a[offset], b[offset]);
-  c[offset] = out[0];
-  d[offset] = out[1];
+  auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
+  for (int i = 0; i < N && (offset + i) < size; ++i) {
+    auto out = Op()(a[offset + i], b[offset + i]);
+    c[offset + i] = out[0];
+    d[offset + i] = out[1];
+  }
 }
diff --git a/mlx/backend/metal/kernels/copy.h b/mlx/backend/metal/kernels/copy.h
index b1367cf4f..2469d1f3d 100644
--- a/mlx/backend/metal/kernels/copy.h
+++ b/mlx/backend/metal/kernels/copy.h
@@ -1,39 +1,53 @@
 // Copyright © 2024 Apple Inc.
 
-template <typename T, typename U>
+template <typename T, typename U, int N = WorkPerThread<T>::n>
 [[kernel]] void copy_s(
     device const T* src [[buffer(0)]],
     device U* dst [[buffer(1)]],
+    constant uint& size,
     uint index [[thread_position_in_grid]]) {
-  dst[index] = static_cast<U>(src[0]);
+  index *= N;
+  for (int i = 0; i < N && (index + i) < size; ++i) {
+    dst[index + i] = static_cast<U>(src[0]);
+  }
 }
 
-template <typename T, typename U>
+template <typename T, typename U, int N = WorkPerThread<T>::n>
 [[kernel]] void copy_v(
     device const T* src [[buffer(0)]],
     device U* dst [[buffer(1)]],
+    constant uint& size,
     uint index [[thread_position_in_grid]]) {
-  dst[index] = static_cast<U>(src[index]);
+  index *= N;
+  for (int i = 0; i < N && (index + i) < size; ++i) {
+    dst[index + i] = static_cast<U>(src[index + i]);
+  }
 }
 
-template <typename T, typename U>
+template <typename T, typename U, int N = WorkPerThread<T>::n>
 [[kernel]] void copy_s2(
     device const T* src [[buffer(0)]],
     device U* dst [[buffer(1)]],
+    constant int64_t& size,
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
-  auto offset = index.x + grid_dim.x * int64_t(index.y);
-  dst[offset] = static_cast<U>(src[0]);
+  auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
+  for (int i = 0; i < N && (offset + i) < size; ++i) {
+    dst[offset + i] = static_cast<U>(src[0]);
+  }
 }
 
-template <typename T, typename U>
+template <typename T, typename U, int N = WorkPerThread<T>::n>
 [[kernel]] void copy_v2(
     device const T* src [[buffer(0)]],
     device U* dst [[buffer(1)]],
+    constant int64_t& size,
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
-  auto offset = index.x + grid_dim.x * int64_t(index.y);
-  dst[offset] = static_cast<U>(src[offset]);
+  auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
+  for (int i = 0; i < N && (offset + i) < size; ++i) {
+    dst[offset + i] = static_cast<U>(src[offset + i]);
+  }
 }
diff --git a/mlx/backend/metal/kernels/ternary.h b/mlx/backend/metal/kernels/ternary.h
index 4b3adcc80..5251dc7e9 100644
--- a/mlx/backend/metal/kernels/ternary.h
+++ b/mlx/backend/metal/kernels/ternary.h
@@ -1,25 +1,32 @@
 // Copyright © 2024 Apple Inc.
 
-template <typename T, typename Op>
+template <typename T, typename Op, int N = WorkPerThread<T>::n>
 [[kernel]] void ternary_v(
     device const bool* a,
     device const T* b,
     device const T* c,
     device T* d,
+    constant uint& size,
     uint index [[thread_position_in_grid]]) {
-  d[index] = Op()(a[index], b[index], c[index]);
+  index *= N;
+  for (int i = 0; i < N && (index + i) < size; ++i) {
+    d[index + i] = Op()(a[index + i], b[index + i], c[index + i]);
+  }
 }
 
-template <typename T, typename Op>
+template <typename T, typename Op, int N = WorkPerThread<T>::n>
 [[kernel]] void ternary_v2(
     device const bool* a,
     device const T* b,
     device const T* c,
     device T* d,
+    constant int64_t& size,
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
-  auto offset = index.x + grid_dim.x * int64_t(index.y);
-  d[offset] = Op()(a[offset], b[offset], c[offset]);
+  auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
+  for (int i = 0; i < N && (offset + i) < size; ++i) {
+    d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]);
+  }
 }
diff --git a/mlx/backend/metal/kernels/unary.h b/mlx/backend/metal/kernels/unary.h
index 69828599f..b5eaab2e9 100644
--- a/mlx/backend/metal/kernels/unary.h
+++ b/mlx/backend/metal/kernels/unary.h
@@ -1,21 +1,28 @@
 // Copyright © 2024 Apple Inc.
 
-template <typename T, typename U, typename Op>
+template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
 [[kernel]] void unary_v(
     device const T* in,
     device U* out,
+    constant uint& size,
     uint index [[thread_position_in_grid]]) {
-  out[index] = Op()(in[index]);
+  index *= N;
+  for (int i = 0; i < N && (index + i) < size; ++i) {
+    out[index + i] = Op()(in[index + i]);
+  }
 }
 
-template <typename T, typename U, typename Op>
+template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
 [[kernel]] void unary_v2(
     device const T* in,
     device U* out,
+    constant int64_t& size,
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
-  auto offset = index.x + grid_dim.x * int64_t(index.y);
-  out[offset] = Op()(in[offset]);
+  auto offset = N * (index.x + grid_dim.x * int64_t(index.y));
+  for (int i = 0; i < N && (offset + i) < size; ++i) {
+    out[offset + i] = Op()(in[offset + i]);
+  }
 }
 
 template <
diff --git a/mlx/backend/metal/kernels/utils.h b/mlx/backend/metal/kernels/utils.h
index 1170d5576..c30d186b8 100644
--- a/mlx/backend/metal/kernels/utils.h
+++ b/mlx/backend/metal/kernels/utils.h
@@ -15,6 +15,14 @@
 
 typedef half float16_t;
 
+// Work per thread values for different types. The values here are expected to
+// match get_work_per_thread in mlx/backend/metal/utils.h
+template <typename U>
+struct WorkPerThread {
+  static_assert(sizeof(U) <= 8, "Type too large");
+  static constexpr int constant n = 8 / sizeof(U);
+};
+
 ///////////////////////////////////////////////////////////////////////////////
 // Type limits utils
 ///////////////////////////////////////////////////////////////////////////////
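
The comment on WorkPerThread pins down the invariant this patch relies on: the
device-side unroll factor 8 / sizeof(U) must agree with the host-side
get_work_per_thread (defined in mlx/backend/metal/utils.h at the end of this
patch), because the host divides the launch size by it. A compile-time check
of that agreement (a plain C++ sketch; the Dtype argument is replaced by a raw
byte size for illustration):

  #include <algorithm>

  // Host-side rule, mirroring get_work_per_thread.
  constexpr int get_work_per_thread(int dtype_size) {
    return std::max(1, 8 / dtype_size);
  }

  // Device-side rule, mirroring WorkPerThread<U>::n.
  template <int Size>
  constexpr int device_n = 8 / Size;

  static_assert(get_work_per_thread(1) == device_n<1>);  // bool/int8: 8
  static_assert(get_work_per_thread(2) == device_n<2>);  // float16:   4
  static_assert(get_work_per_thread(4) == device_n<4>);  // float32:   2
  static_assert(get_work_per_thread(8) == device_n<8>);  // complex64: 1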
diff --git a/mlx/backend/metal/ternary.cpp b/mlx/backend/metal/ternary.cpp
index 36bfd3e2b..0b821151e 100644
--- a/mlx/backend/metal/ternary.cpp
+++ b/mlx/backend/metal/ternary.cpp
@@ -45,7 +45,7 @@ void ternary_op_gpu_inplace(
     work_per_thread = large ? 4 : 2;
   } else {
     large = out.data_size() > INT32_MAX;
-    work_per_thread = 1;
+    work_per_thread = get_work_per_thread(b.dtype());
   }
   std::string kernel_name;
   if (topt == TernaryOpType::General) {
@@ -106,13 +106,19 @@ void ternary_op_gpu_inplace(
     compute_encoder.dispatch_threads(grid_dims, group_dims);
   } else {
     // Launch a 1D or 2D grid of threads
-    size_t nthreads = out.data_size();
+    size_t nthreads = ceildiv(out.data_size(), work_per_thread);
     if (thread_group_size > nthreads) {
       thread_group_size = nthreads;
     }
     MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-    MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
-                                : MTL::Size(nthreads, 1, 1);
+    MTL::Size grid_dims;
+    if (large) {
+      compute_encoder.set_bytes<int64_t>(out.data_size(), 4);
+      grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
+    } else {
+      compute_encoder.set_bytes<uint32_t>(out.data_size(), 4);
+      grid_dims = MTL::Size(nthreads, 1, 1);
+    }
     compute_encoder.dispatch_threads(grid_dims, group_dims);
   }
 }
diff --git a/mlx/backend/metal/unary.cpp b/mlx/backend/metal/unary.cpp
index be43c41c2..368e693a9 100644
--- a/mlx/backend/metal/unary.cpp
+++ b/mlx/backend/metal/unary.cpp
@@ -34,18 +34,19 @@ void unary_op_gpu_inplace(
   };
   auto [shape, strides] = maybe_collapse();
   int ndim = shape.size();
-  size_t nthreads = contig ? in.data_size() : in.size();
   bool large;
   if (!contig) {
     large = in.data_size() > INT32_MAX || out.size() > INT32_MAX;
   } else {
     large = in.data_size() > UINT32_MAX;
   }
-  int work_per_thread = !contig && large ? 4 : 1;
+  int work_per_thread;
   std::string kernel_name;
   if (contig) {
+    work_per_thread = get_work_per_thread(in.dtype());
     kernel_name = (large ? "v2" : "v");
   } else {
+    work_per_thread = large ? 4 : 1;
     kernel_name = "gn" + std::to_string(work_per_thread);
     if (large) {
       kernel_name += "large";
@@ -75,12 +76,20 @@ void unary_op_gpu_inplace(
     MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
     compute_encoder.dispatch_threads(grid_dims, group_dims);
   } else {
+    size_t nthreads = ceildiv(in.data_size(), work_per_thread);
     if (thread_group_size > nthreads) {
       thread_group_size = nthreads;
    }
+
     MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-    MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
-                                : MTL::Size(nthreads, 1, 1);
+    MTL::Size grid_dims;
+    if (large) {
+      compute_encoder.set_bytes<int64_t>(in.data_size(), 2);
+      grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
+    } else {
+      compute_encoder.set_bytes<uint32_t>(in.data_size(), 2);
+      grid_dims = MTL::Size(nthreads, 1, 1);
+    }
     compute_encoder.dispatch_threads(grid_dims, group_dims);
   }
 }
diff --git a/mlx/backend/metal/utils.h b/mlx/backend/metal/utils.h
index 079d15f17..f9245a6d6 100644
--- a/mlx/backend/metal/utils.h
+++ b/mlx/backend/metal/utils.h
@@ -84,4 +84,12 @@ void concatenate(std::string& acc, T first, Args... args) {
   concatenate(acc, args...);
 }
 
+inline int get_work_per_thread(Dtype dtype) {
+  return std::max(1, 8 / dtype.size());
+}
+
+inline size_t ceildiv(size_t n, size_t m) {
+  return (n + m - 1) / m;
+}
+
 } // namespace mlx::core
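
Putting the host pieces together: for a contiguous float32 op, the dispatch
arithmetic introduced by this patch works out as in the standalone sketch
below (the array size is an arbitrary example, and get_work_per_thread is
re-declared here with a plain byte-size parameter):

  #include <algorithm>
  #include <cstddef>
  #include <cstdio>

  // Mirrors get_work_per_thread and ceildiv from mlx/backend/metal/utils.h.
  inline int get_work_per_thread(int dtype_size) {
    return std::max(1, 8 / dtype_size);
  }
  inline size_t ceildiv(size_t n, size_t m) {
    return (n + m - 1) / m;
  }

  int main() {
    size_t data_size = 1000003;                  // output elements
    int wpt = get_work_per_thread(4);            // float32 -> 2 per thread
    size_t nthreads = ceildiv(data_size, wpt);   // threads actually launched
    // The `(index + i) < size` guard in the kernels trims the final chunk.
    std::printf("%zu threads cover %zu slots for %zu elements\n",
                nthreads, nthreads * wpt, data_size);
    return 0;
  }

Before this change the contiguous path used work_per_thread = 1, so a float32
kernel now launches half as many threads and a float16 kernel a quarter as
many, which is presumably the bandwidth ("bw") improvement the commit title
refers to.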