diff --git a/benchmarks/python/comparative/bench_mlx.py b/benchmarks/python/comparative/bench_mlx.py
index 877fa4522..0821ccae6 100644
--- a/benchmarks/python/comparative/bench_mlx.py
+++ b/benchmarks/python/comparative/bench_mlx.py
@@ -144,6 +144,13 @@ def reduction(op, axis, x):
     mx.eval(ys)
 
 
+def sum_and_add(axis, x, y):
+    z = x.sum(axis=axis, keepdims=True)
+    for i in range(50):
+        z = (z + y).sum(axis=axis, keepdims=True)
+    mx.eval(z)
+
+
 def softmax(axis, x):
     ys = []
     for i in range(100):
@@ -505,5 +512,8 @@ if __name__ == "__main__":
     elif args.benchmark == "selu":
         print(bench(selu, x))
 
+    elif args.benchmark == "sum_and_add":
+        print(bench(sum_and_add, axis, *xs))
+
     else:
         raise ValueError("Unknown benchmark")
diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp
index ed1b4b1fd..83554a6fe 100644
--- a/mlx/backend/metal/jit_kernels.cpp
+++ b/mlx/backend/metal/jit_kernels.cpp
@@ -319,16 +319,18 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
 MTL::ComputePipelineState* get_reduce_init_kernel(
     metal::Device& d,
     const std::string& kernel_name,
+    const std::string& func_name,
+    const std::string& op_name,
     const array& out) {
   auto lib = d.get_library(kernel_name, [&]() {
     std::ostringstream kernel_source;
-    std::string op_type = op_name(out);
-    op_type[0] = std::toupper(op_name(out)[0]);
+    std::string op_type = op_name;
+    op_type[0] = std::toupper(op_name[0]);
     auto out_type = get_type_string(out.dtype());
     std::string op = op_type + "<" + out_type + ">";
     kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce();
     kernel_source << get_template_definition(
-        kernel_name, "init_reduce", out_type, op);
+        kernel_name, func_name, out_type, op);
     return kernel_source.str();
   });
   return d.get_kernel(kernel_name, lib);
diff --git a/mlx/backend/metal/kernels.h b/mlx/backend/metal/kernels.h
index d8a258cb8..b5f9b0a92 100644
--- a/mlx/backend/metal/kernels.h
+++ b/mlx/backend/metal/kernels.h
@@ -79,6 +79,8 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
 MTL::ComputePipelineState* get_reduce_init_kernel(
     metal::Device& d,
     const std::string& kernel_name,
+    const std::string& func_name,
+    const std::string& op_name,
     const array& out);
 
 MTL::ComputePipelineState* get_reduce_kernel(
diff --git a/mlx/backend/metal/kernels/reduce.metal b/mlx/backend/metal/kernels/reduce.metal
index 2b55bd9a6..d68045047 100644
--- a/mlx/backend/metal/kernels/reduce.metal
+++ b/mlx/backend/metal/kernels/reduce.metal
@@ -113,9 +113,12 @@ instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or)
 // special case bool with larger output type
 instantiate_all_reduce(sumbool_, bool, uint32_t, Sum)
 
-#define instantiate_col_reduce_small(name, itype, otype, op, dim) \
-  instantiate_kernel("col_reduce_small_" #dim "_reduce_" #name, \
-                     col_reduce_small, \
+#define instantiate_col_reduce_small(name, itype, otype, op, dim)    \
+  instantiate_kernel("col_reduce_small_" #dim "_reduce_" #name,      \
+                     col_reduce_small,                               \
+                     itype, otype, op, dim)                          \
+  instantiate_kernel("col_reduce_longcolumn_" #dim "_reduce_" #name, \
+                     col_reduce_longcolumn,                          \
                      itype, otype, op, dim)
 
 #define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \
   instantiate_kernel("col_reduce_looped_" #dim "_" #bm "_" #bn "_reduce_" #name, \
                      col_reduce_looped, \
                      itype, otype, op, dim, bm, bn)
 
+#define instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, bm, bn) \
+  instantiate_kernel("col_reduce_2pass_" #dim "_" #bm "_" #bn "_reduce_" #name, \
+                     col_reduce_2pass, \
+                     itype, otype, op, dim, bm, bn)
+
 #define instantiate_col_reduce_looped(name, itype, otype, op, dim) \
-  instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 8, 128) \
-  instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32)
+  instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32) \
+  instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, 32, 32)
 
 #define instantiate_col_reduce_general(name, itype, otype, op) \
   instantiate_col_reduce_small(name, itype, otype, op, 0) \
diff --git a/mlx/backend/metal/kernels/reduction/reduce_col.h b/mlx/backend/metal/kernels/reduction/reduce_col.h
index 52e763ddc..735e80afe 100644
--- a/mlx/backend/metal/kernels/reduction/reduce_col.h
+++ b/mlx/backend/metal/kernels/reduction/reduce_col.h
@@ -1,11 +1,6 @@
 // Copyright © 2023-2024 Apple Inc.
 
-template <
-    typename T,
-    typename U,
-    typename Op,
-    int NDIMS,
-    int N_READS = REDUCE_N_READS>
+template <typename T, typename U, typename Op, int NDIMS>
 [[kernel]] void col_reduce_small(
     const device T* in [[buffer(0)]],
     device U* out [[buffer(1)]],
@@ -20,170 +15,128 @@ template <
     const constant size_t& non_col_reductions [[buffer(10)]],
     uint3 gid [[threadgroup_position_in_grid]],
     uint3 gsize [[threadgroups_per_grid]],
-    uint simd_lane_id [[thread_index_in_simdgroup]],
-    uint simd_group_id [[simdgroup_index_in_threadgroup]],
-    uint3 tid [[thread_position_in_grid]],
-    uint3 tsize [[threads_per_grid]]) {
+    uint3 lid [[thread_position_in_threadgroup]],
+    uint3 lsize [[threads_per_threadgroup]]) {
+  constexpr int n_reads = 4;
   Op op;
   looped_elem_to_loc<NDIMS> loop;
   const device T* row;
 
-  // Case 1: Small row small column
-  if (reduction_size * non_col_reductions < 64 && reduction_stride < 32) {
-    U totals[31];
-    for (int i = 0; i < 31; i++) {
-      totals[i] = Op::init;
+  U totals[n_reads];
+  for (int i = 0; i < n_reads; i++) {
+    totals[i] = Op::init;
+  }
+
+  size_t column = size_t(gid.x) * lsize.x * n_reads + lid.x * n_reads;
+  if (column >= reduction_stride) {
+    return;
+  }
+  bool safe = column + n_reads <= reduction_stride;
+
+  size_t out_idx = gid.y + gsize.y * size_t(gid.z);
+  size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
+  in += in_idx + column;
+
+  size_t total_rows = non_col_reductions * reduction_size;
+  loop.next(lid.y, reduce_shape, reduce_strides);
+  for (size_t r = lid.y; r < total_rows; r += lsize.y) {
+    row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
+    if (safe) {
+      for (int i = 0; i < n_reads; i++) {
+        totals[i] = op(static_cast<U>(row[i]), totals[i]);
+      }
+    } else {
+      U vals[n_reads];
+      for (int i = 0; i < n_reads; i++) {
+        vals[i] =
+            (column + i < reduction_stride) ? static_cast<U>(row[i]) : op.init;
+      }
+      for (int i = 0; i < n_reads; i++) {
+        totals[i] = op(vals[i], totals[i]);
+      }
     }
+    loop.next(lsize.y, reduce_shape, reduce_strides);
+  }
 
-    short stride = reduction_stride;
-    short size = reduction_size;
-    short blocks = stride / N_READS;
-    short extra = stride - blocks * N_READS;
-
-    size_t out_idx = tid.x + tsize.y * size_t(tid.y);
-    in += elem_to_loc(out_idx, shape, strides, ndim);
-
-    for (uint r = 0; r < non_col_reductions; r++) {
-      row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
-
-      for (short i = 0; i < size; i++) {
-        for (short j = 0; j < blocks; j++) {
-          for (short k = 0; k < N_READS; k++) {
-            totals[j * N_READS + k] =
-                op(totals[j * N_READS + k],
-                   static_cast<U>(row[i * stride + j * N_READS + k]));
-          }
-        }
-        for (short k = 0; k < extra; k++) {
-          totals[blocks * N_READS + k] =
-              op(totals[blocks * N_READS + k],
-                 static_cast<U>(row[i * stride + blocks * N_READS + k]));
+  if (lsize.y > 1) {
+    // lsize.y should be <= 8
+    threadgroup U shared_vals[32 * 8 * n_reads];
+    for (int i = 0; i < n_reads; i++) {
+      shared_vals[lid.y * lsize.x * n_reads + lid.x * n_reads + i] = totals[i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (lid.y == 0) {
+      for (int i = 0; i < n_reads; i++) {
+        totals[i] = shared_vals[lid.x * n_reads + i];
+      }
+      for (uint j = 1; j < lsize.y; j++) {
+        for (int i = 0; i < n_reads; i++) {
+          totals[i] =
+              op(shared_vals[j * lsize.x * n_reads + lid.x * n_reads + i],
+                 totals[i]);
         }
       }
-
-      loop.next(reduce_shape, reduce_strides);
-    }
-    out += out_idx * reduction_stride;
-    for (short j = 0; j < stride; j++) {
-      out[j] = totals[j];
     }
   }
 
-  // Case 2: Long row small column
-  else if (reduction_size * non_col_reductions < 32) {
-    U totals[N_READS];
-    for (int i = 0; i < N_READS; i++) {
-      totals[i] = Op::init;
-    }
-
-    short size = reduction_size;
-    size_t offset = size_t(tid.x) * N_READS;
-    bool safe = offset + N_READS <= reduction_stride;
-    short extra = reduction_stride - offset;
-
-    size_t out_idx = tid.y + tsize.z * size_t(tid.z);
-    in += elem_to_loc(out_idx, shape, strides, ndim) + offset;
-
-    for (uint r = 0; r < non_col_reductions; r++) {
-      row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
-
-      if (safe) {
-        for (short i = 0; i < size; i++) {
-          for (short j = 0; j < N_READS; j++) {
-            totals[j] =
-                op(static_cast<U>(row[i * reduction_stride + j]), totals[j]);
-          }
-        }
-      } else {
-        for (short i = 0; i < size; i++) {
-          for (short j = 0; j < extra; j++) {
-            totals[j] =
-                op(static_cast<U>(row[i * reduction_stride + j]), totals[j]);
-          }
-        }
-      }
-
-      loop.next(reduce_shape, reduce_strides);
-    }
-    out += out_idx * reduction_stride + offset;
+  if (lid.y == 0) {
+    out += out_idx * reduction_stride + column;
     if (safe) {
-      for (short i = 0; i < N_READS; i++) {
+      for (int i = 0; i < n_reads; i++) {
         out[i] = totals[i];
       }
     } else {
-      for (short i = 0; i < extra; i++) {
+      for (int i = 0; column + i < reduction_stride; i++) {
         out[i] = totals[i];
       }
     }
   }
+}
 
-  // Case 3: Long row medium column
-  else {
-    threadgroup U shared_vals[1024];
-    U totals[N_READS];
-    for (int i = 0; i < N_READS; i++) {
-      totals[i] = Op::init;
-    }
-
-    short stride = reduction_stride;
-    short lid = simd_group_id * simd_size + simd_lane_id;
-    short2 tile((stride + N_READS - 1) / N_READS, 32);
-    short2 offset((lid % tile.x) * N_READS, lid / tile.x);
-    short sm_stride = tile.x * N_READS;
-    bool safe = offset.x + N_READS <= stride;
-
-    size_t out_idx = gid.y + gsize.y * size_t(gid.z);
-    in += elem_to_loc(out_idx, shape, strides, ndim) + offset.x;
-
-    // Read cooperatively and contiguously and aggregate the partial results.
-    size_t total = non_col_reductions * reduction_size;
-    loop.next(offset.y, reduce_shape, reduce_strides);
-    for (size_t r = offset.y; r < total; r += simd_size) {
-      row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
-
-      if (safe) {
-        for (int i = 0; i < N_READS; i++) {
-          totals[i] = op(static_cast<U>(row[i]), totals[i]);
-        }
-      } else {
-        U vals[N_READS];
-        for (int i = 0; i < N_READS; i++) {
-          vals[i] = (offset.x + i < stride) ? static_cast<U>(row[i]) : op.init;
-        }
-        for (int i = 0; i < N_READS; i++) {
-          totals[i] = op(vals[i], totals[i]);
-        }
-      }
-
-      loop.next(simd_size, reduce_shape, reduce_strides);
-    }
-
-    // Each thread holds N_READS partial results but the simdgroups are not
-    // aligned to do the reduction across the simdgroup so we write our results
-    // in the shared memory and read them back according to the simdgroup.
-    for (int i = 0; i < N_READS; i++) {
-      shared_vals[offset.y * sm_stride + offset.x + i] = totals[i];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    for (int i = 0; i < N_READS; i++) {
-      totals[i] = op.simd_reduce(
-          shared_vals[simd_lane_id * sm_stride + simd_group_id * N_READS + i]);
-    }
-
-    // Write the output.
-    if (simd_lane_id == 0) {
-      short column = simd_group_id * N_READS;
-      out += out_idx * reduction_stride + column;
-      if (column + N_READS <= stride) {
-        for (int i = 0; i < N_READS; i++) {
-          out[i] = totals[i];
-        }
-      } else {
-        for (int i = 0; column + i < stride; i++) {
-          out[i] = totals[i];
-        }
-      }
+template <typename T, typename U, typename Op, int NDIMS>
+[[kernel]] void col_reduce_longcolumn(
+    const device T* in [[buffer(0)]],
+    device U* out [[buffer(1)]],
+    const constant size_t& reduction_size [[buffer(2)]],
+    const constant size_t& reduction_stride [[buffer(3)]],
+    const constant int* shape [[buffer(4)]],
+    const constant size_t* strides [[buffer(5)]],
+    const constant int& ndim [[buffer(6)]],
+    const constant int* reduce_shape [[buffer(7)]],
+    const constant size_t* reduce_strides [[buffer(8)]],
+    const constant int& reduce_ndim [[buffer(9)]],
+    const constant size_t& non_col_reductions [[buffer(10)]],
+    const constant size_t& out_size [[buffer(11)]],
+    uint3 gid [[threadgroup_position_in_grid]],
+    uint3 gsize [[threadgroups_per_grid]],
+    uint3 lid [[thread_position_in_threadgroup]],
+    uint3 lsize [[threads_per_threadgroup]]) {
+  Op op;
+  looped_elem_to_loc<NDIMS> loop;
+  const device T* row;
+
+  size_t out_idx = gid.x + gsize.x * size_t(gid.y);
+  size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
+  in += in_idx + lid.x;
+
+  U total = Op::init;
+  size_t total_rows = non_col_reductions * reduction_size;
+  loop.next(gid.z * lsize.y + lid.y, reduce_shape, reduce_strides);
+  for (size_t r = gid.z * lsize.y + lid.y; r < total_rows;
+       r += lsize.y * gsize.z) {
+    row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
+    total = op(static_cast<U>(*row), total);
+    loop.next(lsize.y * gsize.z, reduce_shape, reduce_strides);
+  }
+
+  threadgroup U shared_vals[32 * 32];
+  shared_vals[lid.y * lsize.x + lid.x] = total;
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+  if (lid.y == 0) {
+    for (uint i = 1; i < lsize.y; i++) {
+      total = op(total, shared_vals[i * lsize.x + lid.x]);
     }
+    out[gid.z * out_size + out_idx * reduction_stride + lid.x] = total;
   }
 }
@@ -216,7 +169,7 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
   Op op;
-  constexpr int n_simdgroups = 4;
+  constexpr int n_simdgroups = 8;
   constexpr short tgp_size = n_simdgroups * simd_size;
   constexpr short n_reads = (BM * BN) / tgp_size;
   constexpr short n_read_blocks = BN / n_reads;
@@ -329,3 +282,103 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
     }
   }
 }
+
+template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
+[[kernel]] void col_reduce_2pass(
+    const device T* in [[buffer(0)]],
+    device U* out [[buffer(1)]],
+    const constant size_t& reduction_size [[buffer(2)]],
+    const constant size_t& reduction_stride [[buffer(3)]],
+    const constant int* shape [[buffer(4)]],
+    const constant size_t* strides [[buffer(5)]],
+    const constant int& ndim [[buffer(6)]],
+    const constant int* reduce_shape [[buffer(7)]],
+    const constant size_t* reduce_strides [[buffer(8)]],
+    const constant int& reduce_ndim [[buffer(9)]],
+    const constant size_t& non_col_reductions [[buffer(10)]],
+    const constant size_t& out_size [[buffer(11)]],
+    uint3 gid [[threadgroup_position_in_grid]],
+    uint3 gsize [[threadgroups_per_grid]],
+    uint simd_lane_id [[thread_index_in_simdgroup]],
+    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
+  Op op;
+  constexpr int n_simdgroups = 8;
+  constexpr short tgp_size = n_simdgroups * simd_size;
+  constexpr short n_reads = (BM * BN) / tgp_size;
+  constexpr short n_read_blocks = BN / n_reads;
+  constexpr int n_outputs = BN / n_simdgroups;
+  constexpr short outer_blocks = 32;
+  static_assert(BM == 32, "BM should be equal to 32");
+
+  threadgroup U shared_vals[BN * BM];
+  U totals[n_reads];
+  looped_elem_to_loc<NDIMS> loop;
+  const device T* row;
+
+  for (int i = 0; i < n_reads; i++) {
+    totals[i] = Op::init;
+  }
+
+  short lid = simd_group_id * simd_size + simd_lane_id;
+  short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);
+  size_t column = BN * gid.x + offset.x;
+  bool safe = column + n_reads <= reduction_stride;
+
+  size_t full_idx = gid.y + gsize.y * size_t(gid.z);
+  size_t block_idx = full_idx / out_size;
+  size_t out_idx = full_idx % out_size;
+  size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
+  in += in_idx + column;
+
+  size_t total = non_col_reductions * reduction_size;
+  loop.next(offset.y + block_idx * BM, reduce_shape, reduce_strides);
+  for (size_t r = offset.y + block_idx * BM; r < total;
+       r += outer_blocks * BM) {
+    row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
+
+    if (safe) {
+      for (int i = 0; i < n_reads; i++) {
+        totals[i] = op(static_cast<U>(row[i]), totals[i]);
+      }
+    } else {
+      U vals[n_reads];
+      for (int i = 0; i < n_reads; i++) {
+        vals[i] =
+            (column + i < reduction_stride) ? static_cast<U>(row[i]) : op.init;
+      }
+      for (int i = 0; i < n_reads; i++) {
+        totals[i] = op(vals[i], totals[i]);
+      }
+    }
+
+    loop.next(outer_blocks * BM, reduce_shape, reduce_strides);
+  }
+
+  // We can use a simd reduction to accumulate across BM so each thread writes
+  // the partial output to SM and then each simdgroup does BN / n_simdgroups
+  // accumulations.
+  for (int i = 0; i < n_reads; i++) {
+    shared_vals[offset.y * BN + offset.x + i] = totals[i];
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+  short2 out_offset(simd_group_id * n_outputs, simd_lane_id);
+  for (int i = 0; i < n_outputs; i++) {
+    totals[i] =
+        op.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]);
+  }
+
+  // Write the output.
+  if (simd_lane_id == 0) {
+    size_t out_column = BN * gid.x + out_offset.x;
+    out += full_idx * reduction_stride + out_column;
+    if (out_column + n_outputs <= reduction_stride) {
+      for (int i = 0; i < n_outputs; i++) {
+        out[i] = totals[i];
+      }
+    } else {
+      for (int i = 0; out_column + i < reduction_stride; i++) {
+        out[i] = totals[i];
+      }
+    }
+  }
+}
diff --git a/mlx/backend/metal/nojit_kernels.cpp b/mlx/backend/metal/nojit_kernels.cpp
index 006e2ae46..b5d9f5fa2 100644
--- a/mlx/backend/metal/nojit_kernels.cpp
+++ b/mlx/backend/metal/nojit_kernels.cpp
@@ -97,6 +97,8 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
 MTL::ComputePipelineState* get_reduce_init_kernel(
     metal::Device& d,
     const std::string& kernel_name,
+    const std::string&,
+    const std::string&,
     const array&) {
   return d.get_kernel(kernel_name);
 }
diff --git a/mlx/backend/metal/reduce.cpp b/mlx/backend/metal/reduce.cpp
index d8906a819..6adab0824 100644
--- a/mlx/backend/metal/reduce.cpp
+++ b/mlx/backend/metal/reduce.cpp
@@ -141,6 +141,20 @@ struct ColReduceArgs {
     ndim = shape.size();
   }
 
+  /**
+   * Create the col reduce arguments for reducing the 1st axis of the row
+   * contiguous intermediate array.
+   */
+  ColReduceArgs(const array& intermediate) {
+    assert(intermediate.flags().row_contiguous);
+
+    reduction_size = intermediate.shape(0);
+    reduction_stride = intermediate.size() / reduction_size;
+    non_col_reductions = 1;
+    reduce_ndim = 0;
+    ndim = 0;
+  }
+
   void encode(CommandEncoder& compute_encoder) {
     // Push 0s to avoid encoding empty vectors.
     if (reduce_ndim == 0) {
@@ -231,8 +245,10 @@ void init_reduce(
     CommandEncoder& compute_encoder,
     metal::Device& d,
     const Stream& s) {
-  auto kernel = get_reduce_init_kernel(
-      d, "init_reduce_" + op_name + type_to_name(out), out);
+  std::ostringstream kname;
+  const std::string func_name = "init_reduce";
+  kname << func_name << "_" << op_name << type_to_name(out);
+  auto kernel = get_reduce_init_kernel(d, kname.str(), func_name, op_name, out);
   size_t nthreads = out.size();
   MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
   NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
@@ -251,8 +267,7 @@ void all_reduce_dispatch(
     const std::string& op_name,
     CommandEncoder& compute_encoder,
     metal::Device& d,
-    const Stream& s,
-    std::vector<array>& copies) {
+    const Stream& s) {
   // Set the kernel
   std::ostringstream kname;
   const std::string func_name = "all_reduce";
@@ -293,7 +308,7 @@ void all_reduce_dispatch(
     // Allocate an intermediate tensor to hold results if needed
     array intermediate({n_rows}, out.dtype(), nullptr, {});
     intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
-    copies.push_back(intermediate);
+    d.add_temporary(intermediate, s.index);
 
     // 1st pass
     size_t row_size = (in_size + n_rows - 1) / n_rows;
@@ -469,39 +484,11 @@ void strided_reduce_small(
   // Figure out the grid dims
   MTL::Size grid_dims, group_dims;
 
-  // Case 1: Small row small column
-  if (args.reduction_size * args.non_col_reductions < 64 &&
-      args.reduction_stride < 32) {
-    grid_dims = output_grid_for_col_reduce(out, args);
-    int threadgroup_size = (grid_dims.width > 128) ? 128 : grid_dims.width;
-    group_dims = MTL::Size(threadgroup_size, 1, 1);
-  }
+  // Prepare the arguments for the kernel
+  args.reduce_shape.push_back(args.reduction_size);
+  args.reduce_strides.push_back(args.reduction_stride);
+  args.reduce_ndim++;
 
-  // Case 2: Long row small column
-  else if (args.reduction_size * args.non_col_reductions < 32) {
-    auto out_grid_dims = output_grid_for_col_reduce(out, args);
-    int threads_x =
-        (args.reduction_stride + REDUCE_N_READS - 1) / REDUCE_N_READS;
-    int threadgroup_x = std::min(threads_x, 128);
-    grid_dims = MTL::Size(threads_x, out_grid_dims.width, out_grid_dims.height);
-    group_dims = MTL::Size(threadgroup_x, 1, 1);
-  }
-
-  // Case 3: Long row medium column
-  else {
-    args.reduce_shape.push_back(args.reduction_size);
-    args.reduce_strides.push_back(args.reduction_stride);
-    args.reduce_ndim++;
-    int simdgroups =
-        (args.reduction_stride + REDUCE_N_READS - 1) / REDUCE_N_READS;
-    int threadgroup_size = simdgroups * 32;
-    auto out_grid_dims = output_grid_for_col_reduce(out, args);
-    grid_dims =
-        MTL::Size(threadgroup_size, out_grid_dims.width, out_grid_dims.height);
-    group_dims = MTL::Size(threadgroup_size, 1, 1);
-  }
-
-  // Set the kernel
   int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
   std::ostringstream kname;
   const std::string func_name = "col_reduce_small";
@@ -510,10 +497,113 @@ void strided_reduce_small(
       get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n);
   compute_encoder->setComputePipelineState(kernel);
 
+  const int n_reads = 4;
+  size_t reduction_stride_blocks =
+      (args.reduction_stride + n_reads - 1) / n_reads;
+  size_t total = args.reduction_size * args.non_col_reductions;
+  size_t threadgroup_x = std::min(reduction_stride_blocks, 32ul);
+  size_t threadgroup_y = std::min(
+      8ul,
+      std::min(kernel->maxTotalThreadsPerThreadgroup() / threadgroup_x, total));
+
+  group_dims = MTL::Size(threadgroup_x, threadgroup_y, 1);
+  grid_dims = output_grid_for_col_reduce(out, args);
+  grid_dims = MTL::Size(
+      (reduction_stride_blocks + threadgroup_x - 1) / threadgroup_x,
+      grid_dims.width,
+      grid_dims.height);
+
   // Launch
   compute_encoder.set_input_array(in, 0);
   compute_encoder.set_output_array(out, 1);
   args.encode(compute_encoder);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+}
+
+void strided_reduce_longcolumn(
+    const array& in,
+    array& out,
+    const std::string& op_name,
+    ColReduceArgs& args,
+    CommandEncoder& compute_encoder,
+    metal::Device& d,
+    const Stream& s) {
+  size_t total_reduction_size = args.reduction_size * args.non_col_reductions;
+  size_t outer_blocks = 32;
+  if (total_reduction_size >= 32768) {
+    outer_blocks = 128;
+  }
+
+  // Prepare the temporary accumulator
+  std::vector<int> intermediate_shape;
+  intermediate_shape.reserve(out.ndim() + 1);
+  intermediate_shape.push_back(outer_blocks);
+  intermediate_shape.insert(
+      intermediate_shape.end(), out.shape().begin(), out.shape().end());
+  array intermediate(std::move(intermediate_shape), out.dtype(), nullptr, {});
+  intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
+  d.add_temporary(intermediate, s.index);
+
+  // Prepare the arguments for the kernel
+  args.reduce_shape.push_back(args.reduction_size);
+  args.reduce_strides.push_back(args.reduction_stride);
+  args.reduce_ndim++;
+
+  // Figure out the grid dims
+  size_t out_size = out.size();
+  size_t threadgroup_x = args.reduction_stride;
+  size_t threadgroup_y =
+      (args.non_col_reductions * args.reduction_size + outer_blocks - 1) /
+      outer_blocks;
+  threadgroup_y = std::min(32ul, threadgroup_y);
+
+  auto out_grid_size = output_grid_for_col_reduce(out, args);
+  MTL::Size grid_dims(out_grid_size.width, out_grid_size.height, outer_blocks);
+  MTL::Size group_dims(threadgroup_x, threadgroup_y, 1);
+
+  // Set the kernel
+  int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
+  std::ostringstream kname;
+  const std::string func_name = "col_reduce_longcolumn";
+  kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in);
+  auto kernel =
+      get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n);
+  compute_encoder->setComputePipelineState(kernel);
+
+  // Launch
+  compute_encoder.set_input_array(in, 0);
+  compute_encoder.set_output_array(intermediate, 1);
+  args.encode(compute_encoder);
+  compute_encoder->setBytes(&out_size, sizeof(size_t), 11);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+
+  // Make the 2nd pass arguments and grid_dims
+  ColReduceArgs second_args(intermediate);
+  second_args.reduce_shape.push_back(outer_blocks);
+  second_args.reduce_strides.push_back(out.size());
+  second_args.reduce_ndim++;
+  int BN = 32;
+  grid_dims = MTL::Size(256 * ((out.size() + BN - 1) / BN), 1, 1);
+  group_dims = MTL::Size(256, 1, 1);
+
+  // Set the 2nd kernel
+  const std::string second_kernel = "col_reduce_looped_1_32_32_reduce_" +
+      op_name + type_to_name(intermediate);
+  kernel = get_reduce_kernel(
+      d,
+      second_kernel,
+      "col_reduce_looped",
+      op_name,
+      intermediate,
+      out,
+      1,
+      32,
+      32);
+  compute_encoder->setComputePipelineState(kernel);
+
+  compute_encoder.set_input_array(intermediate, 0);
+  compute_encoder.set_output_array(out, 1);
+  second_args.encode(compute_encoder);
   compute_encoder.dispatchThreads(grid_dims, group_dims);
 }
 
@@ -532,9 +622,9 @@ void strided_reduce_looped(
 
   // Figure out the grid dims
   auto out_grid_size = output_grid_for_col_reduce(out, args);
-  int BN = (args.reduction_stride <= 1024) ? 32 : 128;
+  int BN = 32;
   int BM = 1024 / BN;
-  int threadgroup_size = 4 * 32;
+  int threadgroup_size = 8 * 32;
   MTL::Size grid_dims(
       threadgroup_size * ((args.reduction_stride + BN - 1) / BN),
       out_grid_size.width,
@@ -558,6 +648,87 @@ void strided_reduce_looped(
   compute_encoder.dispatchThreads(grid_dims, group_dims);
 }
 
+void strided_reduce_2pass(
+    const array& in,
+    array& out,
+    const std::string& op_name,
+    ColReduceArgs& args,
+    CommandEncoder& compute_encoder,
+    metal::Device& d,
+    const Stream& s) {
+  // Prepare the temporary accumulator
+  std::vector<int> intermediate_shape;
+  intermediate_shape.reserve(out.ndim() + 1);
+  intermediate_shape.push_back(32);
+  intermediate_shape.insert(
+      intermediate_shape.end(), out.shape().begin(), out.shape().end());
+  array intermediate(std::move(intermediate_shape), out.dtype(), nullptr, {});
+  intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
+  d.add_temporary(intermediate, s.index);
+
+  // Prepare the arguments for the kernel
+  args.reduce_shape.push_back(args.reduction_size);
+  args.reduce_strides.push_back(args.reduction_stride);
+  args.reduce_ndim++;
+
+  // Figure out the grid dims
+  size_t out_size = out.size() / args.reduction_stride;
+  auto out_grid_size = output_grid_for_col_reduce(out, args);
+  int outer_blocks = 32;
+  int BN = 32;
+  int BM = 1024 / BN;
+  int threadgroup_size = 8 * 32;
+  MTL::Size grid_dims(
+      threadgroup_size * ((args.reduction_stride + BN - 1) / BN),
+      out_grid_size.width * outer_blocks,
+      out_grid_size.height);
+  MTL::Size group_dims(threadgroup_size, 1, 1);
+
+  // Set the kernel
+  int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
+  std::ostringstream kname;
+  const std::string func_name = "col_reduce_2pass";
+  kname << func_name << "_" << n << "_" << BM << "_" << BN << "_reduce_"
+        << op_name << type_to_name(in);
+  auto kernel =
+      get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n, BM, BN);
+  compute_encoder->setComputePipelineState(kernel);
+
+  // Launch
+  compute_encoder.set_input_array(in, 0);
+  compute_encoder.set_output_array(intermediate, 1);
+  args.encode(compute_encoder);
+  compute_encoder->setBytes(&out_size, sizeof(size_t), 11);
+  compute_encoder.dispatchThreads(grid_dims, group_dims);
+
+  // Make the 2nd pass arguments and grid_dims
+  ColReduceArgs second_args(intermediate);
+  second_args.reduce_shape.push_back(outer_blocks);
+  second_args.reduce_strides.push_back(out.size());
+  second_args.reduce_ndim++;
+  grid_dims = MTL::Size(threadgroup_size * ((out.size() + BN - 1) / BN), 1, 1);
+
+  // Set the 2nd kernel
+  const std::string second_kernel = "col_reduce_looped_1_32_32_reduce_" +
+      op_name + type_to_name(intermediate);
+  kernel = get_reduce_kernel(
+      d,
+      second_kernel,
+      "col_reduce_looped",
+      op_name,
+      intermediate,
+      out,
+      1,
+      32,
+      32);
+  compute_encoder->setComputePipelineState(kernel);
+
+  compute_encoder.set_input_array(intermediate, 0);
+  compute_encoder.set_output_array(out, 1);
+  second_args.encode(compute_encoder);
+  compute_encoder.dispatchThreads(grid_dims, group_dims);
+}
+
 void strided_reduce_general_dispatch(
     const array& in,
     array& out,
@@ -570,11 +741,23 @@ void strided_reduce_general_dispatch(
   // Prepare the arguments for the kernel
   ColReduceArgs args(in, plan, axes);
 
-  if (args.reduction_stride < 32 ||
-      args.reduction_size * args.non_col_reductions < 32) {
+  // Small column
+  if (args.reduction_size * args.non_col_reductions < 32) {
     return strided_reduce_small(in, out, op_name, args, compute_encoder, d, s);
   }
 
+  // Long column but small row
+  if (args.reduction_stride < 32 &&
+      args.reduction_size * args.non_col_reductions >= 1024) {
+    return strided_reduce_longcolumn(
+        in, out, op_name, args, compute_encoder, d, s);
+  }
+
+  if (args.reduction_size * args.non_col_reductions > 256 &&
+      out.size() / 32 < 1024) {
+    return strided_reduce_2pass(in, out, op_name, args, compute_encoder, d, s);
+  }
+
   return strided_reduce_looped(in, out, op_name, args, compute_encoder, d, s);
 }
 
@@ -620,7 +803,6 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
 
   // Reduce
   if (in.size() > 0) {
-    std::vector<array> copies;
     ReductionPlan plan = get_reduction_plan(in, axes_);
 
     // If it is a general reduce then copy the input to a contiguous array and
@@ -632,7 +814,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
     if (plan.type == GeneralReduce) {
      array in_copy(in.shape(), in.dtype(), nullptr, {});
       copy_gpu(in, in_copy, CopyType::General, s);
-      copies.push_back(in_copy);
+      d.add_temporary(in_copy, s.index);
       in = in_copy;
       plan = get_reduction_plan(in, axes_);
     }
@@ -640,7 +822,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
     // Reducing over everything and the data is all there no broadcasting or
     // slicing etc.
     if (plan.type == ContiguousAllReduce) {
-      all_reduce_dispatch(in, out, op_name, compute_encoder, d, s, copies);
+      all_reduce_dispatch(in, out, op_name, compute_encoder, d, s);
     }
 
     // At least the last dimension is row contiguous and we are reducing over
@@ -659,8 +841,6 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
       strided_reduce_general_dispatch(
           in, out, op_name, plan, axes_, compute_encoder, d, s);
     }
-
-    d.add_temporaries(std::move(copies), s.index);
   }
 
   // Nothing to reduce just initialize the output
diff --git a/mlx/backend/metal/reduce.h b/mlx/backend/metal/reduce.h
index 4d2829a9b..a997d7e24 100644
--- a/mlx/backend/metal/reduce.h
+++ b/mlx/backend/metal/reduce.h
@@ -16,8 +16,7 @@ void all_reduce_dispatch(
     const std::string& op_name,
     CommandEncoder& compute_encoder,
     metal::Device& d,
-    const Stream& s,
-    std::vector<array>& copies);
+    const Stream& s);
 
 void row_reduce_general_dispatch(
     const array& in,