diff --git a/mlx/backend/metal/kernels/logsumexp.h b/mlx/backend/metal/kernels/logsumexp.h
index c746050b3..e7e0930f0 100644
--- a/mlx/backend/metal/kernels/logsumexp.h
+++ b/mlx/backend/metal/kernels/logsumexp.h
@@ -5,28 +5,33 @@ template <typename T, typename AccT = T, int N_READS = 4>
     const device T* in,
     device T* out,
     constant int& axis_size,
-    uint gid [[threadgroup_position_in_grid]],
-    uint _lid [[thread_position_in_threadgroup]],
+    uint2 gid [[threadgroup_position_in_grid]],
+    uint2 tid [[thread_position_in_grid]],
+    uint2 grid_dim [[threads_per_grid]],
+    uint2 _lid [[thread_position_in_threadgroup]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
-  int lid = _lid;
+  int lid = _lid.x;
 
   constexpr int SIMD_SIZE = 32;
+  constexpr int elem_per_group = SIMD_SIZE * 32 * N_READS;
 
   threadgroup AccT local_max[SIMD_SIZE];
   threadgroup AccT local_normalizer[SIMD_SIZE];
 
   AccT ld[N_READS];
 
-  in += gid * size_t(axis_size) + lid * N_READS;
-  if (lid * N_READS + N_READS <= axis_size) {
+  const int axis_offset = tid.y * elem_per_group;
+  in += gid.x * size_t(axis_size) + lid * N_READS + axis_offset;
+  if (axis_offset + lid * N_READS + N_READS <= axis_size) {
     for (int i = 0; i < N_READS; i++) {
       ld[i] = AccT(in[i]);
     }
   } else {
     for (int i = 0; i < N_READS; i++) {
-      ld[i] =
-          ((lid * N_READS + i) < axis_size) ? AccT(in[i]) : Limits<AccT>::min;
+      ld[i] = ((axis_offset + lid * N_READS + i) < axis_size)
+          ? AccT(in[i])
+          : Limits<AccT>::min;
     }
   }
   if (simd_group_id == 0) {
@@ -55,6 +60,7 @@ template <typename T, typename AccT = T, int N_READS = 4>
   maxval = local_max[0];
 
   // Compute exp(x_i - maxval) and store the partial sums in local_normalizer
+  out += gid.x * grid_dim.y + tid.y;
   AccT normalizer = 0;
   for (int i = 0; i < N_READS; i++) {
     normalizer += fast::exp(ld[i] - maxval);
@@ -67,7 +73,7 @@ template <typename T, typename AccT = T, int N_READS = 4>
   if (simd_group_id == 0) {
     normalizer = simd_sum(local_normalizer[simd_lane_id]);
     if (simd_lane_id == 0) {
-      out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
+      out[0] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
     }
   }
 }
diff --git a/mlx/backend/metal/logsumexp.cpp b/mlx/backend/metal/logsumexp.cpp
index e53bc58d9..260289063 100644
--- a/mlx/backend/metal/logsumexp.cpp
+++ b/mlx/backend/metal/logsumexp.cpp
@@ -62,15 +62,37 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
   const int n_reads = 4;
   const int looped_limit = LOGSUMEXP_LOOPED_LIMIT;
 
-  std::string kernel_name = (axis_size > looped_limit) ? "looped_" : "block_";
+  bool split = n_rows < 4 && axis_size > 4 * looped_limit;
+  bool looped = !split && axis_size > looped_limit;
+  std::string kernel_name = looped ? "looped_" : "block_";
   kernel_name += "logsumexp_";
   kernel_name += type_to_name(out);
 
   auto kernel = get_logsumexp_kernel(d, kernel_name, out);
   auto& compute_encoder = d.get_command_encoder(s.index);
+
+  if (split) {
+    auto tmp_size = ceildiv(axis_size, looped_limit);
+    auto tmp_shape = Shape{n_rows, static_cast<int>(tmp_size)};
+    array tmp(tmp_shape, in.dtype(), nullptr, {});
+    tmp.set_data(allocator::malloc(tmp.nbytes()));
+    size_t threadgroup_size = 1024;
+    assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
+    size_t n_threads = n_rows * threadgroup_size;
+    auto grid_dims = MTL::Size(n_threads, tmp_size, 1);
+    auto group_dims = MTL::Size(threadgroup_size, 1, 1);
+    compute_encoder.set_compute_pipeline_state(kernel);
+    compute_encoder.set_input_array(in, 0);
+    compute_encoder.set_output_array(tmp, 1);
+    compute_encoder.set_bytes(axis_size, 2);
+    compute_encoder.dispatch_threads(grid_dims, group_dims);
+    d.add_temporary(tmp, s.index);
+    in = tmp;
+    axis_size = tmp_size;
+  }
+
   {
     MTL::Size grid_dims, group_dims;
-    if (axis_size <= looped_limit) {
+    if (!looped) {
       size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
       size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
       size_t threadgroup_size = simd_size * simds_needed;
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index 7c4f3f8e3..a89124b88 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -760,6 +760,10 @@ class TestOps(mlx_tests.MLXTestCase):
         x = mx.broadcast_to(mx.random.uniform(shape=(2, 1, 8)), (2, 2, 8))
         self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)))
 
+        # Even larger
+        x = mx.random.uniform(shape=(4 * 4096 + 3,))
+        self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)))
+
     def test_mean(self):
         x = mx.array(
             [