From 884b4ed43b302fdbe9ec4352791e683a3a2d410b Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Wed, 21 Feb 2024 19:42:16 -0800 Subject: [PATCH] Fix threadgroup memory in arg reduce (#723) --- mlx/backend/metal/kernels/arg_reduce.metal | 12 +++++------- mlx/backend/metal/primitives.cpp | 2 -- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/mlx/backend/metal/kernels/arg_reduce.metal b/mlx/backend/metal/kernels/arg_reduce.metal index f153b920d..f24a32ce8 100644 --- a/mlx/backend/metal/kernels/arg_reduce.metal +++ b/mlx/backend/metal/kernels/arg_reduce.metal @@ -11,8 +11,6 @@ template struct IndexValPair { uint32_t index; U val; - - IndexValPair(uint32_t _index, U _val) : index(_index), val(_val) {} }; template @@ -65,10 +63,10 @@ struct ArgMax { template IndexValPair simd_shuffle_down(IndexValPair data, uint16_t delta) { - return IndexValPair( + return IndexValPair{ simd_shuffle_down(data.index, delta), simd_shuffle_down(data.val, delta) - ); + }; } @@ -82,7 +80,6 @@ template const device size_t& ndim [[buffer(5)]], const device size_t& axis_stride [[buffer(6)]], const device size_t& axis_size [[buffer(7)]], - threadgroup IndexValPair *local_data [[threadgroup(0)]], uint gid [[thread_position_in_grid]], uint lid [[thread_position_in_threadgroup]], uint lsize [[threads_per_threadgroup]], @@ -111,7 +108,9 @@ template auto in_idx = elem_to_loc(gid / lsize, shape, in_strides, ndim); auto out_idx = elem_to_loc(gid / lsize, shape, out_strides, ndim); - IndexValPair best(0, Op::init); + IndexValPair best{0, Op::init}; + + threadgroup IndexValPair local_data[32]; // Loop over the reduction axis in lsize*N_READS buckets for (uint r=0; r < ceildiv(axis_size, N_READS*lsize); r++) { @@ -172,7 +171,6 @@ template const device size_t& ndim [[buffer(5)]], \ const device size_t& axis_stride [[buffer(6)]], \ const device size_t& axis_size [[buffer(7)]], \ - threadgroup IndexValPair *local_data [[threadgroup(0)]], \ uint gid [[thread_position_in_grid]], \ uint lid [[thread_position_in_threadgroup]], \ uint lsize [[threads_per_threadgroup]], \ diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 38ec5993c..056bbbc80 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -430,8 +430,6 @@ void ArgReduce::eval_gpu(const std::vector& inputs, array& out) { compute_encoder->setBytes(&ndim, sizeof(size_t), 5); compute_encoder->setBytes(&axis_stride, sizeof(size_t), 6); compute_encoder->setBytes(&axis_size, sizeof(size_t), 7); - compute_encoder->setThreadgroupMemoryLength( - simd_size * (sizeof(uint32_t) + in.itemsize()), 0); compute_encoder->dispatchThreads(grid_dims, group_dims); } }