From 6686e61ca443f43e6a86fbbad5ef40b258b0821b Mon Sep 17 00:00:00 2001
From: Jagrit Digani
Date: Mon, 4 Mar 2024 19:09:51 -0800
Subject: [PATCH] Reduce update (#783)

* Split reduction files to reduce compile times

* Add small and medium axis size specializations for row reductions

* Add non-row-reduction options for small and med kernels
---
 benchmarks/python/comparative/bench_mlx.py   |   8 +-
 benchmarks/python/comparative/bench_torch.py |   8 +-
 mlx/backend/metal/kernels/CMakeLists.txt     |  11 +-
 mlx/backend/metal/kernels/reduce.metal       | 619 ------------------
 .../reduction/kernels/reduce_all.metal       | 185 ++++++
 .../reduction/kernels/reduce_col.metal       | 184 ++++++
 .../reduction/kernels/reduce_init.metal      |  33 +
 .../reduction/kernels/reduce_row.metal       | 369 +++++++++++
 .../kernels/{reduce.h => reduction/ops.h}    |   2 +-
 .../metal/kernels/reduction/reduce_inst.h    |  71 ++
 mlx/backend/metal/kernels/reduction/utils.h  |  14 +
 mlx/backend/metal/kernels/scatter.metal      |   2 +-
 mlx/backend/metal/reduce.cpp                 | 110 +++-
 13 files changed, 949 insertions(+), 667 deletions(-)
 delete mode 100644 mlx/backend/metal/kernels/reduce.metal
 create mode 100644 mlx/backend/metal/kernels/reduction/kernels/reduce_all.metal
 create mode 100644 mlx/backend/metal/kernels/reduction/kernels/reduce_col.metal
 create mode 100644 mlx/backend/metal/kernels/reduction/kernels/reduce_init.metal
 create mode 100644 mlx/backend/metal/kernels/reduction/kernels/reduce_row.metal
 rename mlx/backend/metal/kernels/{reduce.h => reduction/ops.h} (98%)
 create mode 100644 mlx/backend/metal/kernels/reduction/reduce_inst.h
 create mode 100644 mlx/backend/metal/kernels/reduction/utils.h

diff --git a/benchmarks/python/comparative/bench_mlx.py b/benchmarks/python/comparative/bench_mlx.py
index 51cd0cfb1..877fa4522 100644
--- a/benchmarks/python/comparative/bench_mlx.py
+++ b/benchmarks/python/comparative/bench_mlx.py
@@ -380,10 +380,6 @@ if __name__ == "__main__":
     if len(args.axis) > 1:
         args.axis.pop(0)
 
-    if args.print_pid:
-        print(os.getpid())
-        input("Press enter to run")
-
     if args.cpu:
         mx.set_default_device(mx.cpu)
     else:
@@ -406,6 +402,10 @@ if __name__ == "__main__":
     x = xs[0]
     axis = args.axis[0]
 
+    if args.print_pid:
+        print(os.getpid())
+        input("Press enter to run")
+
     if args.benchmark == "matmul_square":
         print(bench(matmul_square, x))
 
diff --git a/benchmarks/python/comparative/bench_torch.py b/benchmarks/python/comparative/bench_torch.py
index 1e8649537..a83e1b503 100644
--- a/benchmarks/python/comparative/bench_torch.py
+++ b/benchmarks/python/comparative/bench_torch.py
@@ -331,10 +331,6 @@ if __name__ == "__main__":
     if len(args.axis) > 1:
         args.axis.pop(0)
 
-    if args.print_pid:
-        print(os.getpid())
-        input("Press enter to run")
-
     torch.set_num_threads(1)
     device = "cpu" if args.cpu else "mps"
 
@@ -354,6 +350,10 @@ if __name__ == "__main__":
     x = xs[0]
     axis = args.axis[0]
 
+    if args.print_pid:
+        print(os.getpid())
+        input("Press enter to run")
+
     if args.benchmark == "matmul_square":
         print(bench(matmul_square, x))
 
diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt
index b3721b6d4..23979b8ac 100644
--- a/mlx/backend/metal/kernels/CMakeLists.txt
+++ b/mlx/backend/metal/kernels/CMakeLists.txt
@@ -8,7 +8,6 @@ set(
   ${CMAKE_CURRENT_SOURCE_DIR}/defines.h
   ${CMAKE_CURRENT_SOURCE_DIR}/erf.h
   ${CMAKE_CURRENT_SOURCE_DIR}/indexing.h
-  ${CMAKE_CURRENT_SOURCE_DIR}/reduce.h
   ${CMAKE_CURRENT_SOURCE_DIR}/unary.h
   ${CMAKE_CURRENT_SOURCE_DIR}/utils.h
 )
@@ -24,7 +23,6 @@ set(
   "gemv"
   "quantized"
   "random"
-  "reduce"
   "rope"
   "scan"
   "softmax"
@@ -68,6
+66,15 @@ foreach(KERNEL ${STEEL_KERNELS}) set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR}) endforeach() +file(GLOB_RECURSE REDUCE_KERNELS ${CMAKE_CURRENT_SOURCE_DIR}/reduction/*.metal) +file(GLOB_RECURSE REDUCE_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/reduction/*.h) + +foreach(KERNEL ${REDUCE_KERNELS}) + cmake_path(GET KERNEL STEM TARGET) + build_kernel_base(${TARGET} ${KERNEL} "${REDUCE_HEADERS}") + set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR}) +endforeach() + add_custom_command( OUTPUT ${MLX_METAL_PATH}/mlx.metallib COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o ${MLX_METAL_PATH}/mlx.metallib diff --git a/mlx/backend/metal/kernels/reduce.metal b/mlx/backend/metal/kernels/reduce.metal deleted file mode 100644 index ee00f48ff..000000000 --- a/mlx/backend/metal/kernels/reduce.metal +++ /dev/null @@ -1,619 +0,0 @@ -// Copyright © 2023 Apple Inc. - -#include -#include - -#include "mlx/backend/metal/kernels/defines.h" -#include "mlx/backend/metal/kernels/reduce.h" -#include "mlx/backend/metal/kernels/utils.h" - -using namespace metal; - -static constant uint8_t simd_size = 32; - -template -[[kernel]] void init_reduce( - device T *out [[buffer(0)]], - uint tid [[thread_position_in_grid]]) { - out[tid] = Op::init; -} - -#define instantiate_init_reduce(name, otype, op) \ - template [[host_name("i" #name)]] \ - [[kernel]] void init_reduce( \ - device otype *out [[buffer(1)]], \ - uint tid [[thread_position_in_grid]]); - -/////////////////////////////////////////////////////////////////////////////// -// All reduce -/////////////////////////////////////////////////////////////////////////////// - -template -inline U per_thread_all_reduce( - const device T *in, - const device size_t& in_size, - uint gid, - uint grid_size) { - Op op; - U total_val = Op::init; - - if (gid * N_READS < in_size) { - in += gid * N_READS; - - int r = 0; - for(; r < (int)ceildiv(in_size, grid_size * N_READS) - 1; r++) { - U vals[N_READS] = {op.init}; - - for(int i = 0; i < N_READS; i++) { - vals[i] = static_cast(in[i]); - } - for(int i = 0; i < N_READS; i++) { - total_val = op(vals[i], total_val); - } - - in += grid_size * N_READS; - } - - // Separate case for the last set as we close the reduction size - size_t curr_idx = (gid + r * (size_t)grid_size) * N_READS; - if (curr_idx < in_size) { - int max_reads = in_size - curr_idx; - T vals[N_READS]; - - for(int i = 0, idx = 0; i < N_READS; i++, idx++) { - idx = idx < max_reads ? idx : max_reads - 1; - vals[i] = in[idx]; - } - for(int i = 0; i < N_READS; i++) { - U val = i < max_reads ? vals[i] : Op::init; - total_val = op(static_cast(val), total_val); - } - } - } - - return total_val; -} - -// NB: This kernel assumes threads_per_threadgroup is at most -// 1024. This way with a simd_size of 32, we are guaranteed to -// complete the reduction in two steps of simd-level reductions. 
-template -[[kernel]] void all_reduce( - const device T *in [[buffer(0)]], - device mlx_atomic *out [[buffer(1)]], - const device size_t& in_size [[buffer(2)]], - uint gid [[thread_position_in_grid]], - uint lid [[thread_position_in_threadgroup]], - uint grid_size [[threads_per_grid]], - uint simd_per_group [[simdgroups_per_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - - Op op; - threadgroup U local_vals[simd_size]; - - U total_val = per_thread_all_reduce(in, in_size, gid, grid_size); - - // Reduction within simd group - total_val = op.simd_reduce(total_val); - if (simd_lane_id == 0) { - local_vals[simd_group_id] = total_val; - } - - // Reduction within thread group - threadgroup_barrier(mem_flags::mem_threadgroup); - total_val = lid < simd_per_group ? local_vals[lid] : op.init; - total_val = op.simd_reduce(total_val); - - // Reduction across threadgroups - if (lid == 0) { - op.atomic_update(out, total_val); - } -} - -template -[[kernel]] void all_reduce_no_atomics( - const device T *in [[buffer(0)]], - device U *out [[buffer(1)]], - const device size_t& in_size [[buffer(2)]], - uint gid [[thread_position_in_grid]], - uint lid [[thread_position_in_threadgroup]], - uint grid_size [[threads_per_grid]], - uint simd_per_group [[simdgroups_per_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint thread_group_id [[threadgroup_position_in_grid]]) { - - Op op; - threadgroup U local_vals[simd_size]; - - U total_val = per_thread_all_reduce(in, in_size, gid, grid_size); - - // Reduction within simd group (simd_add isn't supported for uint64/int64 types) - for (uint16_t lane_offset = simd_size/2; lane_offset > 0; lane_offset /= 2) { - total_val = op(total_val, simd_shuffle_down(total_val, lane_offset)); - } - // Write simd group reduction results to local memory - if (simd_lane_id == 0) { - local_vals[simd_group_id] = total_val; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Reduction of simdgroup reduction results within threadgroup. - total_val = lid < simd_per_group ? 
local_vals[lid] : op.init; - for (uint16_t lane_offset = simd_size/2; lane_offset > 0; lane_offset /= 2) { - total_val = op(total_val, simd_shuffle_down(total_val, lane_offset)); - } - - // Reduction across threadgroups - if (lid == 0) { - out[thread_group_id] = total_val; - } -} - -#define instantiate_all_reduce(name, itype, otype, op) \ - template [[host_name("all_reduce_" #name)]] \ - [[kernel]] void all_reduce( \ - const device itype *in [[buffer(0)]], \ - device mlx_atomic *out [[buffer(1)]], \ - const device size_t& in_size [[buffer(2)]], \ - uint gid [[thread_position_in_grid]], \ - uint lid [[thread_position_in_threadgroup]], \ - uint grid_size [[threads_per_grid]], \ - uint simd_per_group [[simdgroups_per_threadgroup]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]]); - -#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \ - template [[host_name("all_reduce_no_atomics_" #name)]] \ - [[kernel]] void all_reduce_no_atomics( \ - const device itype *in [[buffer(0)]], \ - device otype *out [[buffer(1)]], \ - const device size_t& in_size [[buffer(2)]], \ - uint gid [[thread_position_in_grid]], \ - uint lid [[thread_position_in_threadgroup]], \ - uint grid_size [[threads_per_grid]], \ - uint simd_per_group [[simdgroups_per_threadgroup]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]], \ - uint thread_group_id [[threadgroup_position_in_grid]]); - -/////////////////////////////////////////////////////////////////////////////// -// Row atomics -/////////////////////////////////////////////////////////////////////////////// - -template -inline U per_thread_row_reduce( - const device T *in, - const constant size_t& reduction_size, - const constant size_t& out_size, - const constant int* shape, - const constant size_t* strides, - const constant int& ndim, - uint lsize_x, - uint lid_x, - uint2 tid) { - - Op op; - - // Each threadgroup handles 1 reduction - // TODO: Specializing elem_to_loc would be slightly faster - int idx = tid.y * out_size + tid.x; - int extra_offset = elem_to_loc(idx, shape, strides, ndim); - in += extra_offset + lid_x * N_READS; - - // The reduction is accumulated here - U total_val = Op::init; - - // Loop over the reduction size within thread group - int r = 0; - for (; r < (int)ceildiv(reduction_size, N_READS*lsize_x) - 1; r++) { - T vals[N_READS]; - for(int i = 0; i < N_READS; i++) { - vals[i] = in[i]; - } - for(int i = 0; i < N_READS; i++) { - total_val = op(static_cast(vals[i]), total_val); - } - - in += lsize_x * N_READS; - } - - // Separate case for the last set as we close the reduction size - size_t reduction_index = (lid_x + (size_t)lsize_x * r) * N_READS; - if(reduction_index < reduction_size) { - int max_reads = reduction_size - reduction_index; - - T vals[N_READS]; - for(int i = 0; i < N_READS; i++) { - int idx = min(i, max_reads - 1); - vals[i] = static_cast(in[idx]); - } - for(int i = 0; i < N_READS; i++) { - T val = i < max_reads ? 
vals[i] : Op::init; - total_val = op(static_cast(val), total_val); - } - } - - return total_val; -} - -template -[[kernel]] void row_reduce_general( - const device T *in [[buffer(0)]], - device mlx_atomic *out [[buffer(1)]], - const constant size_t& reduction_size [[buffer(2)]], - const constant size_t& out_size [[buffer(3)]], - const constant int* shape [[buffer(4)]], - const constant size_t* strides [[buffer(5)]], - const constant int& ndim [[buffer(6)]], - uint3 lid [[thread_position_in_threadgroup]], - uint3 lsize [[threads_per_threadgroup]], - uint3 tid [[threadgroup_position_in_grid]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_per_group [[simdgroups_per_threadgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - - Op op; - threadgroup U local_vals[simd_size]; - - U total_val = per_thread_row_reduce(in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy); - - total_val = op.simd_reduce(total_val); - - // Prepare next level - if (simd_lane_id == 0) { - local_vals[simd_group_id] = total_val; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Reduction within thread group - // Only needed if multiple simd groups - if(reduction_size > simd_size) { - total_val = lid.x < simd_per_group ? local_vals[lid.x] : op.init; - total_val = op.simd_reduce(total_val); - } - // Update output - if (lid.x == 0) { - op.atomic_update(out, total_val, tid.x); - } -} - -template -[[kernel]] void row_reduce_general_no_atomics( - const device T *in [[buffer(0)]], - device U *out [[buffer(1)]], - const constant size_t& reduction_size [[buffer(2)]], - const constant size_t& out_size [[buffer(3)]], - const constant int* shape [[buffer(4)]], - const constant size_t* strides [[buffer(5)]], - const constant int& ndim [[buffer(6)]], - uint3 lid [[thread_position_in_threadgroup]], - uint3 lsize [[threads_per_threadgroup]], - uint3 gsize [[threads_per_grid]], - uint3 tid [[threadgroup_position_in_grid]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_per_group [[simdgroups_per_threadgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - - Op op; - - threadgroup U local_vals[simd_size]; - U total_val = per_thread_row_reduce(in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy); - - // Reduction within simd group - simd_add isn't supported for int64 types - for (uint16_t i = simd_size/2; i > 0; i /= 2) { - total_val = op(total_val, simd_shuffle_down(total_val, i)); - } - - // Prepare next level - if (simd_lane_id == 0) { - local_vals[simd_group_id] = total_val; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Reduction within thread group - // Only needed if thread group has multiple simd groups - if(ceildiv(reduction_size, N_READS) > simd_size) { - total_val = lid.x < simd_per_group ? 
local_vals[lid.x] : op.init; - for (uint16_t i = simd_size/2; i > 0; i /= 2) { - total_val = op(total_val, simd_shuffle_down(total_val, i)); - } - } - // Write row reduce output for threadgroup with 1st thread in thread group - if (lid.x == 0) { - out[(ceildiv(gsize.y, lsize.y) * tid.x) + tid.y] = total_val; - } -} - -#define instantiate_row_reduce_general(name, itype, otype, op) \ - template [[host_name("row_reduce_general_" #name)]] \ - [[kernel]] void row_reduce_general( \ - const device itype *in [[buffer(0)]], \ - device mlx_atomic *out [[buffer(1)]], \ - const constant size_t& reduction_size [[buffer(2)]], \ - const constant size_t& out_size [[buffer(3)]], \ - const constant int* shape [[buffer(4)]], \ - const constant size_t* strides [[buffer(5)]], \ - const constant int& ndim [[buffer(6)]], \ - uint3 lid [[thread_position_in_threadgroup]], \ - uint3 lsize [[threads_per_threadgroup]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_per_group [[simdgroups_per_threadgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]]); - -#define instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \ - template [[host_name("row_reduce_general_no_atomics_" #name)]] \ - [[kernel]] void row_reduce_general_no_atomics( \ - const device itype *in [[buffer(0)]], \ - device otype *out [[buffer(1)]], \ - const constant size_t& reduction_size [[buffer(2)]], \ - const constant size_t& out_size [[buffer(3)]], \ - const constant int* shape [[buffer(4)]], \ - const constant size_t* strides [[buffer(5)]], \ - const constant int& ndim [[buffer(6)]], \ - uint3 lid [[thread_position_in_threadgroup]], \ - uint3 lsize [[threads_per_threadgroup]], \ - uint3 gsize [[threads_per_grid]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_per_group [[simdgroups_per_threadgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]]); - -/////////////////////////////////////////////////////////////////////////////// -// Column reduce -/////////////////////////////////////////////////////////////////////////////// - -template -inline U _contiguous_strided_reduce( - const device T *in, - threadgroup U *local_data, - uint in_idx, - uint reduction_size, - uint reduction_stride, - uint2 tid, - uint2 lid, - uint2 lsize) { - - Op op; - U total_val = Op::init; - - uint base_offset = (tid.y * lsize.y + lid.y) * N_READS; - for(uint r = 0; r < N_READS && (base_offset + r) < reduction_size; r++) { - uint offset = base_offset + r; - total_val = op(static_cast(total_val), in[in_idx + offset * reduction_stride]); - } - local_data[lsize.y * lid.x + lid.y] = total_val; - threadgroup_barrier(mem_flags::mem_threadgroup); - - U val = Op::init; - if(lid.y == 0) { - // Perform reduction across columns in thread group - for(uint i = 0; i < lsize.y; i++) { - val = op(val, local_data[lsize.y * lid.x + i]); - } - } - - return val; -} - -template -[[kernel]] void col_reduce_general( - const device T *in [[buffer(0)]], - device mlx_atomic *out [[buffer(1)]], - const constant size_t& reduction_size [[buffer(2)]], - const constant size_t& reduction_stride [[buffer(3)]], - const constant size_t& out_size [[buffer(4)]], - const constant int* shape [[buffer(5)]], - const constant size_t* strides [[buffer(6)]], - const constant int& ndim [[buffer(7)]], - threadgroup U *local_data [[threadgroup(0)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint3 
lsize [[threads_per_threadgroup]]) { - auto out_idx = tid.x * lsize.x + lid.x; - auto in_idx = elem_to_loc( - out_idx + tid.z * out_size, - shape, - strides, - ndim - ); - - Op op; - if(out_idx < out_size) { - U val = _contiguous_strided_reduce( - in, - local_data, - in_idx, - reduction_size, - reduction_stride, - tid.xy, - lid.xy, - lsize.xy); - - // Write out reduction results generated by threadgroups working on specific output element, contiguously. - if (lid.y == 0) { - op.atomic_update(out, val, out_idx); - } - } -} - -template -[[kernel]] void col_reduce_general_no_atomics( - const device T *in [[buffer(0)]], - device U *out [[buffer(1)]], - const constant size_t& reduction_size [[buffer(2)]], - const constant size_t& reduction_stride [[buffer(3)]], - const constant size_t& out_size [[buffer(4)]], - const constant int* shape [[buffer(5)]], - const constant size_t* strides [[buffer(6)]], - const constant int& ndim [[buffer(7)]], - threadgroup U *local_data [[threadgroup(0)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint3 gid [[thread_position_in_grid]], - uint3 lsize [[threads_per_threadgroup]], - uint3 gsize [[threads_per_grid]]) { - auto out_idx = tid.x * lsize.x + lid.x; - auto in_idx = elem_to_loc( - out_idx + tid.z * out_size, - shape, - strides, - ndim - ); - - if(out_idx < out_size) { - U val = _contiguous_strided_reduce( - in, - local_data, - in_idx, - reduction_size, - reduction_stride, - tid.xy, - lid.xy, - lsize.xy); - - // Write out reduction results generated by threadgroups working on specific output element, contiguously. - if (lid.y == 0) { - uint tgsize_y = ceildiv(gsize.y, lsize.y); - uint tgsize_z = ceildiv(gsize.z, lsize.z); - out[tgsize_y * tgsize_z * gid.x + tgsize_y * tid.z + tid.y] = val; - } - } -} - -#define instantiate_col_reduce_general(name, itype, otype, op) \ - template [[host_name("col_reduce_general_" #name)]] \ - [[kernel]] void col_reduce_general( \ - const device itype *in [[buffer(0)]], \ - device mlx_atomic *out [[buffer(1)]], \ - const constant size_t& reduction_size [[buffer(2)]], \ - const constant size_t& reduction_stride [[buffer(3)]], \ - const constant size_t& out_size [[buffer(4)]], \ - const constant int* shape [[buffer(5)]], \ - const constant size_t* strides [[buffer(6)]], \ - const constant int& ndim [[buffer(7)]], \ - threadgroup otype *local_data [[threadgroup(0)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]], \ - uint3 lsize [[threads_per_threadgroup]]); - -#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \ - template [[host_name("col_reduce_general_no_atomics_" #name)]] \ - [[kernel]] void col_reduce_general_no_atomics( \ - const device itype *in [[buffer(0)]], \ - device otype *out [[buffer(1)]], \ - const constant size_t& reduction_size [[buffer(2)]], \ - const constant size_t& reduction_stride [[buffer(3)]], \ - const constant size_t& out_size [[buffer(4)]], \ - const constant int* shape [[buffer(5)]], \ - const constant size_t* strides [[buffer(6)]], \ - const constant int& ndim [[buffer(7)]], \ - threadgroup otype *local_data [[threadgroup(0)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]], \ - uint3 gid [[thread_position_in_grid]], \ - uint3 lsize [[threads_per_threadgroup]], \ - uint3 gsize [[threads_per_grid]]); - -/////////////////////////////////////////////////////////////////////////////// -// Instantiations 
-/////////////////////////////////////////////////////////////////////////////// - -#define instantiate_reduce(name, itype, otype, op) \ - instantiate_all_reduce(name, itype, otype, op) \ - instantiate_row_reduce_general(name, itype, otype, op) \ - instantiate_col_reduce_general(name, itype, otype, op) - -#define instantiate_reduce_no_atomics(name, itype, otype, op) \ - instantiate_all_reduce_no_atomics(name, itype, otype, op) \ - instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \ - instantiate_col_reduce_general_no_atomics(name, itype, otype, op) - -#define instantiate_same_reduce_no_atomics(name, tname, type, op) \ - instantiate_init_reduce(name ##tname, type, op) \ - instantiate_reduce_no_atomics(name ##tname, type, type, op) - -#define instantiate_same_reduce(name, tname, type, op) \ - instantiate_init_reduce(name ##tname, type, op) \ - instantiate_reduce(name ##tname, type, type, op) - -#define instantiate_reduce_from_types_helper(name, tname, itype, otype, op) \ - instantiate_reduce(name ##tname, itype, otype, op) - -#define instantiate_reduce_from_types(name, otype, op) \ - instantiate_reduce_from_types_helper(name, bool_, bool, otype, op) \ - instantiate_reduce_from_types_helper(name, uint8, uint8_t, otype, op) \ - instantiate_reduce_from_types_helper(name, uint16, uint16_t, otype, op) \ - instantiate_reduce_from_types_helper(name, uint32, uint32_t, otype, op) \ - instantiate_reduce_from_types_helper(name, int8, int8_t, otype, op) \ - instantiate_reduce_from_types_helper(name, int16, int16_t, otype, op) \ - instantiate_reduce_from_types_helper(name, int32, int32_t, otype, op) \ - instantiate_reduce_from_types_helper(name, int64, int64_t, otype, op) \ - instantiate_reduce_from_types_helper(name, float16, half, otype, op) \ - instantiate_reduce_from_types_helper(name, float32, float, otype, op) \ - instantiate_reduce_from_types_helper(name, bfloat16, bfloat16_t, otype, op) - -// special case bool with larger output type -instantiate_reduce(sumbool_, bool, uint32_t, Sum) -instantiate_same_reduce(sum, uint8, uint8_t, Sum) -instantiate_same_reduce(sum, uint16, uint16_t, Sum) -instantiate_same_reduce(sum, uint32, uint32_t, Sum) -instantiate_same_reduce(sum, int8, int8_t, Sum) -instantiate_same_reduce(sum, int16, int16_t, Sum) -instantiate_same_reduce(sum, int32, int32_t, Sum) -instantiate_same_reduce(sum, float16, half, Sum) -instantiate_same_reduce(sum, float32, float, Sum) - -instantiate_same_reduce_no_atomics(sum, int64, int64_t, Sum) -instantiate_same_reduce_no_atomics(sum, uint64, uint64_t, Sum) - -instantiate_same_reduce(prod, uint8, uint8_t, Prod) -instantiate_same_reduce(prod, uint16, uint16_t, Prod) -instantiate_same_reduce(prod, uint32, uint32_t, Prod) -instantiate_same_reduce(prod, int8, int8_t, Prod) -instantiate_same_reduce(prod, int16, int16_t, Prod) -instantiate_same_reduce(prod, int32, int32_t, Prod) -instantiate_same_reduce(prod, float16, half, Prod) -instantiate_same_reduce(prod, float32, float, Prod) - -instantiate_same_reduce_no_atomics(prod, int64, int64_t, Prod) -instantiate_same_reduce_no_atomics(prod, uint64, uint64_t, Prod) - -instantiate_same_reduce(sum, bfloat16, bfloat16_t, Sum) -instantiate_same_reduce(prod, bfloat16, bfloat16_t, Prod) - -instantiate_init_reduce(andbool_, bool, And) -instantiate_reduce_from_types(and, bool, And) - -instantiate_init_reduce(orbool_, bool, Or) -instantiate_reduce_from_types(or, bool, Or) - -// Compiler segfaulted with the names "min" or "max" ... 
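For instance, one line of the block below expands (an abridged, illustrative trace of the macros above):

    // instantiate_same_reduce(min_, float32, float, Min) pastes together
    //   instantiate_init_reduce(min_float32, float, Min)
    //   instantiate_reduce(min_float32, float, float, Min)
    // i.e. one init kernel plus all-/row-/col-reduce instantiations whose
    // host names become "imin_float32", "all_reduce_min_float32", and so on.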
-instantiate_same_reduce(min_, uint8, uint8_t, Min) -instantiate_same_reduce(min_, uint16, uint16_t, Min) -instantiate_same_reduce(min_, uint32, uint32_t, Min) -instantiate_same_reduce(min_, int8, int8_t, Min) -instantiate_same_reduce(min_, int16, int16_t, Min) -instantiate_same_reduce(min_, int32, int32_t, Min) -instantiate_same_reduce(min_, float16, half, Min) -instantiate_same_reduce(min_, float32, float, Min) - -instantiate_same_reduce_no_atomics(min_, int64, int64_t, Min) -instantiate_same_reduce_no_atomics(min_, uint64, uint64_t, Min) - -instantiate_same_reduce(max_, uint8, uint8_t, Max) -instantiate_same_reduce(max_, uint16, uint16_t, Max) -instantiate_same_reduce(max_, uint32, uint32_t, Max) -instantiate_same_reduce(max_, int8, int8_t, Max) -instantiate_same_reduce(max_, int16, int16_t, Max) -instantiate_same_reduce(max_, int32, int32_t, Max) -instantiate_same_reduce(max_, float16, half, Max) -instantiate_same_reduce(max_, float32, float, Max) - -instantiate_same_reduce_no_atomics(max_, int64, int64_t, Max) -instantiate_same_reduce_no_atomics(max_, uint64, uint64_t, Max) - -instantiate_same_reduce(min_, bfloat16, bfloat16_t, Min) -instantiate_same_reduce(max_, bfloat16, bfloat16_t, Max) diff --git a/mlx/backend/metal/kernels/reduction/kernels/reduce_all.metal b/mlx/backend/metal/kernels/reduction/kernels/reduce_all.metal new file mode 100644 index 000000000..46c75301a --- /dev/null +++ b/mlx/backend/metal/kernels/reduction/kernels/reduce_all.metal @@ -0,0 +1,185 @@ +// Copyright © 2023-2024 Apple Inc. + +#include "mlx/backend/metal/kernels/reduction/utils.h" +#include "mlx/backend/metal/kernels/reduction/ops.h" +#include "mlx/backend/metal/kernels/reduction/reduce_inst.h" + +using namespace metal; + +/////////////////////////////////////////////////////////////////////////////// +// All reduce helper +/////////////////////////////////////////////////////////////////////////////// + +template +METAL_FUNC U per_thread_all_reduce( + const device T* in, + const device size_t& in_size, + uint gid, + uint grid_size) { + Op op; + U total_val = Op::init; + + if (gid * N_READS < in_size) { + in += gid * N_READS; + + int r = 0; + for (; r < (int)ceildiv(in_size, grid_size * N_READS) - 1; r++) { + U vals[N_READS] = {op.init}; + + for (int i = 0; i < N_READS; i++) { + vals[i] = static_cast(in[i]); + } + for (int i = 0; i < N_READS; i++) { + total_val = op(vals[i], total_val); + } + + in += grid_size * N_READS; + } + + // Separate case for the last set as we close the reduction size + size_t curr_idx = (gid + r * (size_t)grid_size) * N_READS; + if (curr_idx < in_size) { + int max_reads = in_size - curr_idx; + T vals[N_READS]; + + for (int i = 0, idx = 0; i < N_READS; i++, idx++) { + idx = idx < max_reads ? idx : max_reads - 1; + vals[i] = in[idx]; + } + for (int i = 0; i < N_READS; i++) { + U val = i < max_reads ? vals[i] : Op::init; + total_val = op(static_cast(val), total_val); + } + } + } + + return total_val; +} + +/////////////////////////////////////////////////////////////////////////////// +// All reduce kernel +/////////////////////////////////////////////////////////////////////////////// + + +// NB: This kernel assumes threads_per_threadgroup is at most +// 1024. This way with a simd_size of 32, we are guaranteed to +// complete the reduction in two steps of simd-level reductions. 
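To see the arithmetic behind that guarantee: 1024 threads split into at most 1024 / 32 = 32 simdgroups, so one simd-level reduction leaves at most 32 partials, which themselves fit in a single simdgroup. A host-side C++ sketch of the two-step scheme (stand-in names, sum as the op; illustrative only, not part of the patch):

    // Two-step threadgroup reduction modeled on the host.
    // simd_reduce() stands in for Metal's op.simd_reduce over one simdgroup.
    #include <cstdio>

    constexpr int simd_size = 32;

    float simd_reduce(const float* lanes, int n) {
      float acc = 0.0f; // Op::init for Sum
      for (int i = 0; i < n; ++i)
        acc += lanes[i];
      return acc;
    }

    int main() {
      const int threads = 1024; // maximum threads_per_threadgroup
      float vals[1024];
      for (int i = 0; i < threads; ++i)
        vals[i] = 1.0f;

      // Step 1: each of the 1024 / 32 = 32 simdgroups reduces its own lanes.
      float partials[simd_size];
      for (int g = 0; g < threads / simd_size; ++g)
        partials[g] = simd_reduce(vals + g * simd_size, simd_size);

      // Step 2: the 32 partials fit in one simdgroup, so a single further
      // simd-level reduction completes the threadgroup reduction.
      printf("%g\n", simd_reduce(partials, threads / simd_size)); // 1024
    }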
+template +[[kernel]] void all_reduce( + const device T *in [[buffer(0)]], + device mlx_atomic *out [[buffer(1)]], + const device size_t& in_size [[buffer(2)]], + uint gid [[thread_position_in_grid]], + uint lid [[thread_position_in_threadgroup]], + uint grid_size [[threads_per_grid]], + uint simd_per_group [[simdgroups_per_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + + Op op; + threadgroup U local_vals[simd_size]; + + U total_val = per_thread_all_reduce(in, in_size, gid, grid_size); + + // Reduction within simd group + total_val = op.simd_reduce(total_val); + if (simd_lane_id == 0) { + local_vals[simd_group_id] = total_val; + } + + // Reduction within thread group + threadgroup_barrier(mem_flags::mem_threadgroup); + total_val = lid < simd_per_group ? local_vals[lid] : op.init; + total_val = op.simd_reduce(total_val); + + // Reduction across threadgroups + if (lid == 0) { + op.atomic_update(out, total_val); + } +} + +template +[[kernel]] void all_reduce_no_atomics( + const device T *in [[buffer(0)]], + device U *out [[buffer(1)]], + const device size_t& in_size [[buffer(2)]], + uint gid [[thread_position_in_grid]], + uint lid [[thread_position_in_threadgroup]], + uint grid_size [[threads_per_grid]], + uint simd_per_group [[simdgroups_per_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint thread_group_id [[threadgroup_position_in_grid]]) { + + Op op; + threadgroup U local_vals[simd_size]; + + U total_val = per_thread_all_reduce(in, in_size, gid, grid_size); + + // Reduction within simd group (simd_add isn't supported for uint64/int64 types) + for (uint16_t lane_offset = simd_size/2; lane_offset > 0; lane_offset /= 2) { + total_val = op(total_val, simd_shuffle_down(total_val, lane_offset)); + } + // Write simd group reduction results to local memory + if (simd_lane_id == 0) { + local_vals[simd_group_id] = total_val; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Reduction of simdgroup reduction results within threadgroup. + total_val = lid < simd_per_group ? 
local_vals[lid] : op.init; + for (uint16_t lane_offset = simd_size/2; lane_offset > 0; lane_offset /= 2) { + total_val = op(total_val, simd_shuffle_down(total_val, lane_offset)); + } + + // Reduction across threadgroups + if (lid == 0) { + out[thread_group_id] = total_val; + } +} + +#define instantiate_all_reduce(name, itype, otype, op) \ + template [[host_name("all_reduce_" #name)]] \ + [[kernel]] void all_reduce( \ + const device itype *in [[buffer(0)]], \ + device mlx_atomic *out [[buffer(1)]], \ + const device size_t& in_size [[buffer(2)]], \ + uint gid [[thread_position_in_grid]], \ + uint lid [[thread_position_in_threadgroup]], \ + uint grid_size [[threads_per_grid]], \ + uint simd_per_group [[simdgroups_per_threadgroup]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]]); + +#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \ + template [[host_name("all_reduce_no_atomics_" #name)]] \ + [[kernel]] void all_reduce_no_atomics( \ + const device itype *in [[buffer(0)]], \ + device otype *out [[buffer(1)]], \ + const device size_t& in_size [[buffer(2)]], \ + uint gid [[thread_position_in_grid]], \ + uint lid [[thread_position_in_threadgroup]], \ + uint grid_size [[threads_per_grid]], \ + uint simd_per_group [[simdgroups_per_threadgroup]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]], \ + uint thread_group_id [[threadgroup_position_in_grid]]); + +/////////////////////////////////////////////////////////////////////////////// +// Instantiations +/////////////////////////////////////////////////////////////////////////////// + +#define instantiate_same_all_reduce_helper(name, tname, type, op) \ + instantiate_all_reduce(name ##tname, type, type, op) + +#define instantiate_same_all_reduce_na_helper(name, tname, type, op) \ + instantiate_all_reduce_no_atomics(name ##tname, type, type, op) + +instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_types) +instantiate_reduce_ops(instantiate_same_all_reduce_na_helper, instantiate_reduce_helper_64b) + +instantiate_reduce_from_types(instantiate_all_reduce, and, bool, And) +instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or) + +// special case bool with larger output type +instantiate_all_reduce(sumbool_, bool, uint32_t, Sum) \ No newline at end of file diff --git a/mlx/backend/metal/kernels/reduction/kernels/reduce_col.metal b/mlx/backend/metal/kernels/reduction/kernels/reduce_col.metal new file mode 100644 index 000000000..07b889052 --- /dev/null +++ b/mlx/backend/metal/kernels/reduction/kernels/reduce_col.metal @@ -0,0 +1,184 @@ +// Copyright © 2023-2024 Apple Inc. 
+ +#include "mlx/backend/metal/kernels/reduction/utils.h" +#include "mlx/backend/metal/kernels/reduction/ops.h" +#include "mlx/backend/metal/kernels/reduction/reduce_inst.h" + +using namespace metal; + +/////////////////////////////////////////////////////////////////////////////// +// Column reduce helper +/////////////////////////////////////////////////////////////////////////////// + +template +METAL_FUNC U _contiguous_strided_reduce( + const device T* in, + threadgroup U* local_data, + uint in_idx, + uint reduction_size, + uint reduction_stride, + uint2 tid, + uint2 lid, + uint2 lsize) { + Op op; + U total_val = Op::init; + + uint base_offset = (tid.y * lsize.y + lid.y) * N_READS; + for (uint r = 0; r < N_READS && (base_offset + r) < reduction_size; r++) { + uint offset = base_offset + r; + total_val = + op(static_cast(total_val), in[in_idx + offset * reduction_stride]); + } + local_data[lsize.y * lid.x + lid.y] = total_val; + threadgroup_barrier(mem_flags::mem_threadgroup); + + U val = Op::init; + if (lid.y == 0) { + // Perform reduction across columns in thread group + for (uint i = 0; i < lsize.y; i++) { + val = op(val, local_data[lsize.y * lid.x + i]); + } + } + + return val; +} + +/////////////////////////////////////////////////////////////////////////////// +// Column reduce kernel +/////////////////////////////////////////////////////////////////////////////// + +template +[[kernel]] void col_reduce_general( + const device T *in [[buffer(0)]], + device mlx_atomic *out [[buffer(1)]], + const constant size_t& reduction_size [[buffer(2)]], + const constant size_t& reduction_stride [[buffer(3)]], + const constant size_t& out_size [[buffer(4)]], + const constant int* shape [[buffer(5)]], + const constant size_t* strides [[buffer(6)]], + const constant int& ndim [[buffer(7)]], + threadgroup U *local_data [[threadgroup(0)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 lsize [[threads_per_threadgroup]]) { + auto out_idx = tid.x * lsize.x + lid.x; + auto in_idx = elem_to_loc( + out_idx + tid.z * out_size, + shape, + strides, + ndim + ); + + Op op; + if(out_idx < out_size) { + U val = _contiguous_strided_reduce( + in, + local_data, + in_idx, + reduction_size, + reduction_stride, + tid.xy, + lid.xy, + lsize.xy); + + // Write out reduction results generated by threadgroups working on specific output element, contiguously. 
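+    // (After _contiguous_strided_reduce, only the thread with lid.y == 0
+    // holds the column's fully accumulated value, so that single thread
+    // performs the update for this out_idx.)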
+ if (lid.y == 0) { + op.atomic_update(out, val, out_idx); + } + } +} + +template +[[kernel]] void col_reduce_general_no_atomics( + const device T *in [[buffer(0)]], + device U *out [[buffer(1)]], + const constant size_t& reduction_size [[buffer(2)]], + const constant size_t& reduction_stride [[buffer(3)]], + const constant size_t& out_size [[buffer(4)]], + const constant int* shape [[buffer(5)]], + const constant size_t* strides [[buffer(6)]], + const constant int& ndim [[buffer(7)]], + threadgroup U *local_data [[threadgroup(0)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 gid [[thread_position_in_grid]], + uint3 lsize [[threads_per_threadgroup]], + uint3 gsize [[threads_per_grid]]) { + auto out_idx = tid.x * lsize.x + lid.x; + auto in_idx = elem_to_loc( + out_idx + tid.z * out_size, + shape, + strides, + ndim + ); + + if(out_idx < out_size) { + U val = _contiguous_strided_reduce( + in, + local_data, + in_idx, + reduction_size, + reduction_stride, + tid.xy, + lid.xy, + lsize.xy); + + // Write out reduction results generated by threadgroups working on specific output element, contiguously. + if (lid.y == 0) { + uint tgsize_y = ceildiv(gsize.y, lsize.y); + uint tgsize_z = ceildiv(gsize.z, lsize.z); + out[tgsize_y * tgsize_z * gid.x + tgsize_y * tid.z + tid.y] = val; + } + } +} + +#define instantiate_col_reduce_general(name, itype, otype, op) \ + template [[host_name("col_reduce_general_" #name)]] \ + [[kernel]] void col_reduce_general( \ + const device itype *in [[buffer(0)]], \ + device mlx_atomic *out [[buffer(1)]], \ + const constant size_t& reduction_size [[buffer(2)]], \ + const constant size_t& reduction_stride [[buffer(3)]], \ + const constant size_t& out_size [[buffer(4)]], \ + const constant int* shape [[buffer(5)]], \ + const constant size_t* strides [[buffer(6)]], \ + const constant int& ndim [[buffer(7)]], \ + threadgroup otype *local_data [[threadgroup(0)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint3 lsize [[threads_per_threadgroup]]); + +#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \ + template [[host_name("col_reduce_general_no_atomics_" #name)]] \ + [[kernel]] void col_reduce_general_no_atomics( \ + const device itype *in [[buffer(0)]], \ + device otype *out [[buffer(1)]], \ + const constant size_t& reduction_size [[buffer(2)]], \ + const constant size_t& reduction_stride [[buffer(3)]], \ + const constant size_t& out_size [[buffer(4)]], \ + const constant int* shape [[buffer(5)]], \ + const constant size_t* strides [[buffer(6)]], \ + const constant int& ndim [[buffer(7)]], \ + threadgroup otype *local_data [[threadgroup(0)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint3 gid [[thread_position_in_grid]], \ + uint3 lsize [[threads_per_threadgroup]], \ + uint3 gsize [[threads_per_grid]]); + +/////////////////////////////////////////////////////////////////////////////// +// Instantiations +/////////////////////////////////////////////////////////////////////////////// + +#define instantiate_same_col_reduce_helper(name, tname, type, op) \ + instantiate_col_reduce_general(name ##tname, type, type, op) + +#define instantiate_same_col_reduce_na_helper(name, tname, type, op) \ + instantiate_col_reduce_general_no_atomics(name ##tname, type, type, op) + +instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types) 
+instantiate_reduce_ops(instantiate_same_col_reduce_na_helper, instantiate_reduce_helper_64b)
+
+instantiate_col_reduce_general(sumbool_, bool, uint32_t, Sum)
+instantiate_reduce_from_types(instantiate_col_reduce_general, and, bool, And)
+instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or)
\ No newline at end of file
diff --git a/mlx/backend/metal/kernels/reduction/kernels/reduce_init.metal b/mlx/backend/metal/kernels/reduction/kernels/reduce_init.metal
new file mode 100644
index 000000000..7e9bd06da
--- /dev/null
+++ b/mlx/backend/metal/kernels/reduction/kernels/reduce_init.metal
@@ -0,0 +1,33 @@
+// Copyright © 2023-2024 Apple Inc.
+
+#include "mlx/backend/metal/kernels/reduction/utils.h"
+#include "mlx/backend/metal/kernels/reduction/ops.h"
+#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
+
+using namespace metal;
+
+///////////////////////////////////////////////////////////////////////////////
+// Reduce init
+///////////////////////////////////////////////////////////////////////////////
+
+template <typename T, typename Op>
+[[kernel]] void init_reduce(
+    device T *out [[buffer(0)]],
+    uint tid [[thread_position_in_grid]]) {
+  out[tid] = Op::init;
+}
+
+#define instantiate_init_reduce(name, otype, op) \
+  template [[host_name("i" #name)]] \
+  [[kernel]] void init_reduce<otype, op>( \
+      device otype *out [[buffer(1)]], \
+      uint tid [[thread_position_in_grid]]);
+
+#define instantiate_init_reduce_helper(name, tname, type, op) \
+  instantiate_init_reduce(name ##tname, type, op)
+
+instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_types)
+instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_64b)
+
+instantiate_init_reduce(andbool_, bool, And)
+instantiate_init_reduce(orbool_, bool, Or)
\ No newline at end of file
diff --git a/mlx/backend/metal/kernels/reduction/kernels/reduce_row.metal b/mlx/backend/metal/kernels/reduction/kernels/reduce_row.metal
new file mode 100644
index 000000000..1eac8de48
--- /dev/null
+++ b/mlx/backend/metal/kernels/reduction/kernels/reduce_row.metal
@@ -0,0 +1,369 @@
+// Copyright © 2023-2024 Apple Inc.
+ +#include "mlx/backend/metal/kernels/reduction/utils.h" +#include "mlx/backend/metal/kernels/reduction/ops.h" +#include "mlx/backend/metal/kernels/reduction/reduce_inst.h" + +using namespace metal; + +/////////////////////////////////////////////////////////////////////////////// +// Small row reductions +/////////////////////////////////////////////////////////////////////////////// + +// Each thread reduces for one output +template +[[kernel]] void row_reduce_general_small( + const device T *in [[buffer(0)]], + device U *out [[buffer(1)]], + const constant size_t& reduction_size [[buffer(2)]], + const constant size_t& out_size [[buffer(3)]], + const constant size_t& non_row_reductions [[buffer(4)]], + const constant int* shape [[buffer(5)]], + const constant size_t* strides [[buffer(6)]], + const constant int& ndim [[buffer(7)]], + uint lid [[thread_position_in_grid]]) { + + Op op; + + uint out_idx = lid; + + if(out_idx >= out_size) { + return; + } + + U total_val = Op::init; + + for(short r = 0; r < short(non_row_reductions); r++) { + uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim); + const device T * in_row = in + in_idx; + + for(short i = 0; i < short(reduction_size); i++) { + total_val = op(static_cast(in_row[i]), total_val); + } + } + + out[out_idx] = total_val; +} + +// Each simdgroup reduces for one output +template +[[kernel]] void row_reduce_general_med( + const device T *in [[buffer(0)]], + device U *out [[buffer(1)]], + const constant size_t& reduction_size [[buffer(2)]], + const constant size_t& out_size [[buffer(3)]], + const constant size_t& non_row_reductions [[buffer(4)]], + const constant int* shape [[buffer(5)]], + const constant size_t* strides [[buffer(6)]], + const constant int& ndim [[buffer(7)]], + uint tid [[threadgroup_position_in_grid]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_per_group [[dispatch_simdgroups_per_threadgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + + Op op; + + uint out_idx = simd_per_group * tid + simd_group_id; + + if(out_idx >= out_size) { + return; + } + + U total_val = Op::init; + + if(short(non_row_reductions) == 1) { + uint in_idx = elem_to_loc(out_idx, shape, strides, ndim); + const device T * in_row = in + in_idx; + + for(short i = simd_lane_id; i < short(reduction_size); i += 32) { + total_val = op(static_cast(in_row[i]), total_val); + } + } + + else if (short(non_row_reductions) >= 32) { + + for(short r = simd_lane_id; r < short(non_row_reductions); r+=32) { + + uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim); + const device T * in_row = in + in_idx; + + for(short i = 0; i < short(reduction_size); i++) { + total_val = op(static_cast(in_row[i]), total_val); + } + + } + + } + + else { + + const short n_reductions = short(reduction_size) * short(non_row_reductions); + const short reductions_per_thread = (n_reductions + simd_size - 1) / simd_size; + + const short r_st = simd_lane_id / reductions_per_thread; + const short r_ed = short(non_row_reductions); + const short r_jump = simd_size / reductions_per_thread; + + const short i_st = simd_lane_id % reductions_per_thread; + const short i_ed = short(reduction_size); + const short i_jump = reductions_per_thread; + + for(short r = r_st; r < r_ed; r += r_jump) { + + uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim); + const device T * in_row = in + in_idx; + + for(short i = i_st; i < i_ed; i += i_jump) { + total_val = op(static_cast(in_row[i]), total_val); + } + + } + + } + + + 
total_val = op.simd_reduce(total_val); + + if(simd_lane_id == 0) { + out[out_idx] = total_val; + } +} + +#define instantiate_row_reduce_small(name, itype, otype, op) \ + template[[host_name("row_reduce_general_small_" #name)]] \ + [[kernel]] void row_reduce_general_small( \ + const device itype *in [[buffer(0)]], \ + device otype *out [[buffer(1)]], \ + const constant size_t& reduction_size [[buffer(2)]], \ + const constant size_t& out_size [[buffer(3)]], \ + const constant size_t& non_row_reductions [[buffer(4)]], \ + const constant int* shape [[buffer(5)]], \ + const constant size_t* strides [[buffer(6)]], \ + const constant int& ndim [[buffer(7)]], \ + uint lid [[thread_position_in_grid]]); \ + template[[host_name("row_reduce_general_med_" #name)]] \ + [[kernel]] void row_reduce_general_med( \ + const device itype *in [[buffer(0)]], \ + device otype *out [[buffer(1)]], \ + const constant size_t& reduction_size [[buffer(2)]], \ + const constant size_t& out_size [[buffer(3)]], \ + const constant size_t& non_row_reductions [[buffer(4)]], \ + const constant int* shape [[buffer(5)]], \ + const constant size_t* strides [[buffer(6)]], \ + const constant int& ndim [[buffer(7)]], \ + uint tid [[threadgroup_position_in_grid]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_per_group [[dispatch_simdgroups_per_threadgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]]); + +/////////////////////////////////////////////////////////////////////////////// +// Large row reductions +/////////////////////////////////////////////////////////////////////////////// + +template +METAL_FUNC U per_thread_row_reduce( + const device T* in, + const constant size_t& reduction_size, + const constant size_t& out_size, + const constant int* shape, + const constant size_t* strides, + const constant int& ndim, + uint lsize_x, + uint lid_x, + uint2 tid) { + Op op; + + // Each threadgroup handles 1 reduction + // TODO: Specializing elem_to_loc would be slightly faster + int idx = tid.y * out_size + tid.x; + int extra_offset = elem_to_loc(idx, shape, strides, ndim); + in += extra_offset + lid_x * N_READS; + + // The reduction is accumulated here + U total_val = Op::init; + + // Loop over the reduction size within thread group + int r = 0; + for (; r < (int)ceildiv(reduction_size, N_READS * lsize_x) - 1; r++) { + T vals[N_READS]; + for (int i = 0; i < N_READS; i++) { + vals[i] = in[i]; + } + for (int i = 0; i < N_READS; i++) { + total_val = op(static_cast(vals[i]), total_val); + } + + in += lsize_x * N_READS; + } + + // Separate case for the last set as we close the reduction size + size_t reduction_index = (lid_x + (size_t)lsize_x * r) * N_READS; + if (reduction_index < reduction_size) { + int max_reads = reduction_size - reduction_index; + + T vals[N_READS]; + for (int i = 0; i < N_READS; i++) { + int idx = min(i, max_reads - 1); + vals[i] = static_cast(in[idx]); + } + for (int i = 0; i < N_READS; i++) { + T val = i < max_reads ? 
vals[i] : Op::init; + total_val = op(static_cast(val), total_val); + } + } + + return total_val; +} + +template +[[kernel]] void row_reduce_general( + const device T *in [[buffer(0)]], + device mlx_atomic *out [[buffer(1)]], + const constant size_t& reduction_size [[buffer(2)]], + const constant size_t& out_size [[buffer(3)]], + const constant size_t& non_row_reductions [[buffer(4)]], + const constant int* shape [[buffer(5)]], + const constant size_t* strides [[buffer(6)]], + const constant int& ndim [[buffer(7)]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 lsize [[threads_per_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_per_group [[simdgroups_per_threadgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + + (void)non_row_reductions; + + Op op; + threadgroup U local_vals[simd_size]; + + U total_val = per_thread_row_reduce(in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy); + + total_val = op.simd_reduce(total_val); + + // Prepare next level + if (simd_lane_id == 0) { + local_vals[simd_group_id] = total_val; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Reduction within thread group + // Only needed if multiple simd groups + if(reduction_size > simd_size) { + total_val = lid.x < simd_per_group ? local_vals[lid.x] : op.init; + total_val = op.simd_reduce(total_val); + } + // Update output + if (lid.x == 0) { + op.atomic_update(out, total_val, tid.x); + } +} + +template +[[kernel]] void row_reduce_general_no_atomics( + const device T *in [[buffer(0)]], + device U *out [[buffer(1)]], + const constant size_t& reduction_size [[buffer(2)]], + const constant size_t& out_size [[buffer(3)]], + const constant size_t& non_row_reductions [[buffer(4)]], + const constant int* shape [[buffer(5)]], + const constant size_t* strides [[buffer(6)]], + const constant int& ndim [[buffer(7)]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 lsize [[threads_per_threadgroup]], + uint3 gsize [[threads_per_grid]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_per_group [[simdgroups_per_threadgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + + (void)non_row_reductions; + + Op op; + + threadgroup U local_vals[simd_size]; + U total_val = per_thread_row_reduce(in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy); + + // Reduction within simd group - simd_add isn't supported for int64 types + for (uint16_t i = simd_size/2; i > 0; i /= 2) { + total_val = op(total_val, simd_shuffle_down(total_val, i)); + } + + // Prepare next level + if (simd_lane_id == 0) { + local_vals[simd_group_id] = total_val; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Reduction within thread group + // Only needed if thread group has multiple simd groups + if(ceildiv(reduction_size, N_READS) > simd_size) { + total_val = lid.x < simd_per_group ? 
local_vals[lid.x] : op.init; + for (uint16_t i = simd_size/2; i > 0; i /= 2) { + total_val = op(total_val, simd_shuffle_down(total_val, i)); + } + } + // Write row reduce output for threadgroup with 1st thread in thread group + if (lid.x == 0) { + out[(ceildiv(gsize.y, lsize.y) * tid.x) + tid.y] = total_val; + } +} + +#define instantiate_row_reduce_general(name, itype, otype, op) \ + instantiate_row_reduce_small(name, itype, otype, op) \ + template [[host_name("row_reduce_general_" #name)]] \ + [[kernel]] void row_reduce_general( \ + const device itype *in [[buffer(0)]], \ + device mlx_atomic *out [[buffer(1)]], \ + const constant size_t& reduction_size [[buffer(2)]], \ + const constant size_t& out_size [[buffer(3)]], \ + const constant size_t& non_row_reductions [[buffer(4)]], \ + const constant int* shape [[buffer(5)]], \ + const constant size_t* strides [[buffer(6)]], \ + const constant int& ndim [[buffer(7)]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint3 lsize [[threads_per_threadgroup]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_per_group [[simdgroups_per_threadgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]]); + +#define instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \ + instantiate_row_reduce_small(name, itype, otype, op) \ + template [[host_name("row_reduce_general_no_atomics_" #name)]] \ + [[kernel]] void row_reduce_general_no_atomics( \ + const device itype *in [[buffer(0)]], \ + device otype *out [[buffer(1)]], \ + const constant size_t& reduction_size [[buffer(2)]], \ + const constant size_t& out_size [[buffer(3)]], \ + const constant size_t& non_row_reductions [[buffer(4)]], \ + const constant int* shape [[buffer(5)]], \ + const constant size_t* strides [[buffer(6)]], \ + const constant int& ndim [[buffer(7)]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint3 lsize [[threads_per_threadgroup]], \ + uint3 gsize [[threads_per_grid]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_per_group [[simdgroups_per_threadgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]]); + + +/////////////////////////////////////////////////////////////////////////////// +// Instantiations +/////////////////////////////////////////////////////////////////////////////// + +#define instantiate_same_row_reduce_helper(name, tname, type, op) \ + instantiate_row_reduce_general(name ##tname, type, type, op) + +#define instantiate_same_row_reduce_na_helper(name, tname, type, op) \ + instantiate_row_reduce_general_no_atomics(name ##tname, type, type, op) + +instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_types) +instantiate_reduce_ops(instantiate_same_row_reduce_na_helper, instantiate_reduce_helper_64b) + + +instantiate_reduce_from_types(instantiate_row_reduce_general, and, bool, And) +instantiate_reduce_from_types(instantiate_row_reduce_general, or, bool, Or) + +instantiate_row_reduce_general(sumbool_, bool, uint32_t, Sum) \ No newline at end of file diff --git a/mlx/backend/metal/kernels/reduce.h b/mlx/backend/metal/kernels/reduction/ops.h similarity index 98% rename from mlx/backend/metal/kernels/reduce.h rename to mlx/backend/metal/kernels/reduction/ops.h index 70701aebd..ea0c495d9 100644 --- a/mlx/backend/metal/kernels/reduce.h +++ b/mlx/backend/metal/kernels/reduction/ops.h @@ -1,4 +1,4 @@ -// Copyright © 2023 Apple Inc. 
+// Copyright © 2023-2024 Apple Inc. #pragma once diff --git a/mlx/backend/metal/kernels/reduction/reduce_inst.h b/mlx/backend/metal/kernels/reduction/reduce_inst.h new file mode 100644 index 000000000..593db7e62 --- /dev/null +++ b/mlx/backend/metal/kernels/reduction/reduce_inst.h @@ -0,0 +1,71 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include +#include + +#include "mlx/backend/metal/kernels/defines.h" +#include "mlx/backend/metal/kernels/reduction/ops.h" + +#define instantiate_reduce_helper_floats(inst_f, name, op) \ + inst_f(name, float16, half, op) inst_f(name, float32, float, op) \ + inst_f(name, bfloat16, bfloat16_t, op) + +#define instantiate_reduce_helper_uints(inst_f, name, op) \ + inst_f(name, uint8, uint8_t, op) inst_f(name, uint16, uint16_t, op) \ + inst_f(name, uint32, uint32_t, op) + +#define instantiate_reduce_helper_ints(inst_f, name, op) \ + inst_f(name, int8, int8_t, op) inst_f(name, int16, int16_t, op) \ + inst_f(name, int32, int32_t, op) + +#define instantiate_reduce_helper_64b(inst_f, name, op) \ + inst_f(name, int64, int64_t, op) inst_f(name, uint64, uint64_t, op) + +#define instantiate_reduce_helper_types(inst_f, name, op) \ + instantiate_reduce_helper_floats(inst_f, name, op) \ + instantiate_reduce_helper_uints(inst_f, name, op) \ + instantiate_reduce_helper_ints(inst_f, name, op) + +#define instantiate_reduce_ops(inst_f, type_f) \ + type_f(inst_f, sum, Sum) type_f(inst_f, prod, Prod) \ + type_f(inst_f, min_, Min) type_f(inst_f, max_, Max) + +// Special case for bool reductions +#define instantiate_reduce_from_types_helper( \ + inst_f, name, tname, itype, otype, op) \ + inst_f(name##tname, itype, otype, op) + +#define instantiate_reduce_from_types(inst_f, name, otype, op) \ + instantiate_reduce_from_types_helper(inst_f, name, bool_, bool, otype, op) \ + instantiate_reduce_from_types_helper( \ + inst_f, name, uint8, uint8_t, otype, op) \ + instantiate_reduce_from_types_helper( \ + inst_f, name, uint16, uint16_t, otype, op) \ + instantiate_reduce_from_types_helper( \ + inst_f, name, uint32, uint32_t, otype, op) \ + instantiate_reduce_from_types_helper( \ + inst_f, name, int8, int8_t, otype, op) \ + instantiate_reduce_from_types_helper( \ + inst_f, name, int16, int16_t, otype, op) \ + instantiate_reduce_from_types_helper( \ + inst_f, name, int32, int32_t, otype, op) \ + instantiate_reduce_from_types_helper( \ + inst_f, name, int64, int64_t, otype, op) \ + instantiate_reduce_from_types_helper( \ + inst_f, name, float16, half, otype, op) \ + instantiate_reduce_from_types_helper( \ + inst_f, \ + name, \ + float32, \ + float, \ + otype, \ + op) \ + instantiate_reduce_from_types_helper( \ + inst_f, \ + name, \ + bfloat16, \ + bfloat16_t, \ + otype, \ + op) \ No newline at end of file diff --git a/mlx/backend/metal/kernels/reduction/utils.h b/mlx/backend/metal/kernels/reduction/utils.h new file mode 100644 index 000000000..6665a3eca --- /dev/null +++ b/mlx/backend/metal/kernels/reduction/utils.h @@ -0,0 +1,14 @@ +// Copyright © 2024 Apple Inc. 
diff --git a/mlx/backend/metal/kernels/reduction/utils.h b/mlx/backend/metal/kernels/reduction/utils.h
new file mode 100644
index 000000000..6665a3eca
--- /dev/null
+++ b/mlx/backend/metal/kernels/reduction/utils.h
@@ -0,0 +1,14 @@
+// Copyright © 2024 Apple Inc.
+
+#pragma once
+
+#include <metal_atomic>
+#include <metal_simdgroup>
+
+#include "mlx/backend/metal/kernels/defines.h"
+#include "mlx/backend/metal/kernels/steel/utils.h"
+#include "mlx/backend/metal/kernels/utils.h"
+
+#include "mlx/backend/metal/kernels/reduction/ops.h"
+
+static constant constexpr const uint8_t simd_size = 32;
\ No newline at end of file
diff --git a/mlx/backend/metal/kernels/scatter.metal b/mlx/backend/metal/kernels/scatter.metal
index 071effeea..d8cec1336 100644
--- a/mlx/backend/metal/kernels/scatter.metal
+++ b/mlx/backend/metal/kernels/scatter.metal
@@ -4,7 +4,7 @@
 
 #include "mlx/backend/metal/kernels/bf16.h"
 #include "mlx/backend/metal/kernels/indexing.h"
-#include "mlx/backend/metal/kernels/reduce.h"
+#include "mlx/backend/metal/kernels/reduction/ops.h"
 #include "mlx/backend/metal/kernels/utils.h"
 
 using namespace metal;
diff --git a/mlx/backend/metal/reduce.cpp b/mlx/backend/metal/reduce.cpp
index 9b7e729da..8a19c602e 100644
--- a/mlx/backend/metal/reduce.cpp
+++ b/mlx/backend/metal/reduce.cpp
@@ -1,4 +1,4 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
 
 #include <algorithm>
 #include <cassert>
@@ -130,15 +130,8 @@ void row_reduce_general_dispatch(
     const Stream& s) {
   Dtype out_dtype = out.dtype();
   bool is_out_64b_int = is_64b_int(out_dtype);
-  auto kernel = (is_out_64b_int)
-      ? d.get_kernel(
-            "row_reduce_general_no_atomics_" + op_name + type_to_name(in))
-      : d.get_kernel("row_reduce_general_" + op_name + type_to_name(in));
-
-  compute_encoder->setComputePipelineState(kernel);
 
   // Prepare the arguments for the kernel
-  int n_reads = REDUCE_N_READS;
   size_t reduction_size = plan.shape.back();
   auto shape = plan.shape;
   auto strides = plan.strides;
@@ -160,32 +153,72 @@ void row_reduce_general_dispatch(
   }
   int ndim = shape.size();
 
-  // Each thread group is responsible for 1 output
-  NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
-  thread_group_size =
-      std::min((reduction_size + n_reads - 1) / n_reads, thread_group_size);
+  // Determine dispatch kernel
+  std::ostringstream kname;
 
-  // Align thread group size with simd_size
-  uint simd_size = kernel->threadExecutionWidth();
-  thread_group_size =
-      (thread_group_size + simd_size - 1) / simd_size * simd_size;
-  assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup());
+  bool is_small = non_row_reductions * reduction_size < 32;
+  bool is_med = non_row_reductions * reduction_size <= 256;
+  is_out_64b_int &= !is_small && !is_med;
 
-  // Launch enough thread groups for each output
-  size_t n_threads = out.size() * thread_group_size;
-  MTL::Size grid_dims = MTL::Size(n_threads, non_row_reductions, 1);
-  MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
+  std::string small_desc = "_";
+  if (is_small) {
+    small_desc = "_small_";
+  } else if (is_med) {
+    small_desc = "_med_";
+  }
 
-  if (is_out_64b_int == false || non_row_reductions == 1) {
+  small_desc = is_out_64b_int ? "_no_atomics_" : small_desc;
+
+  kname << "row_reduce_general" << small_desc << op_name << type_to_name(in);
+
+  auto kernel = d.get_kernel(kname.str());
+  compute_encoder->setComputePipelineState(kernel);
+
+  // Get dispatch grid dims
+  MTL::Size grid_dims;
+  MTL::Size group_dims;
+
+  // Each thread handles one output
+  if (is_small) {
+    grid_dims = MTL::Size(out.size(), 1, 1);
+    group_dims = MTL::Size(std::min(1024ul, out.size()), 1, 1);
+  }
+  // Each simdgroup handles one output
+  else if (is_med) {
+    grid_dims = MTL::Size(out.size() * 32, 1, 1);
+    group_dims = MTL::Size(std::min(8ul, out.size()) * 32, 1, 1);
+  }
+  // Each threadgroup handles one output
+  else {
+    int n_reads = REDUCE_N_READS;
+    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
+    thread_group_size =
+        std::min((reduction_size + n_reads - 1) / n_reads, thread_group_size);
+
+    // Align thread group size with simd_size
+    uint simd_size = kernel->threadExecutionWidth();
+    thread_group_size =
+        (thread_group_size + simd_size - 1) / simd_size * simd_size;
+    assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup());
+
+    // Launch enough thread groups for each output
+    size_t n_threads = out.size() * thread_group_size;
+    grid_dims = MTL::Size(n_threads, non_row_reductions, 1);
+    group_dims = MTL::Size(thread_group_size, 1, 1);
+  }
+
+  // Dispatch kernel
+  if (!is_out_64b_int || non_row_reductions == 1) {
    // Set the arguments for the kernel
     set_array_buffer(compute_encoder, in, 0);
     set_array_buffer(compute_encoder, out, 1);
     compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
     compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
-    compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4);
+    compute_encoder->setBytes(&non_row_reductions, sizeof(size_t), 4);
+    compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5);
     compute_encoder->setBytes(
-        strides.data(), strides.size() * sizeof(size_t), 5);
-    compute_encoder->setBytes(&ndim, sizeof(int), 6);
+        strides.data(), strides.size() * sizeof(size_t), 6);
+    compute_encoder->setBytes(&ndim, sizeof(int), 7);
 
     compute_encoder->dispatchThreads(grid_dims, group_dims);
 
   } else {
@@ -203,10 +236,11 @@ void row_reduce_general_dispatch(
     set_array_buffer(compute_encoder, intermediate, 1);
     compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
     compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
-    compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4);
+    compute_encoder->setBytes(&non_row_reductions, sizeof(size_t), 4);
+    compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5);
     compute_encoder->setBytes(
-        strides.data(), strides.size() * sizeof(size_t), 5);
-    compute_encoder->setBytes(&ndim, sizeof(int), 6);
+        strides.data(), strides.size() * sizeof(size_t), 6);
+    compute_encoder->setBytes(&ndim, sizeof(int), 7);
 
     compute_encoder->dispatchThreads(grid_dims, group_dims);
     // Set up second dispatch
@@ -230,24 +264,27 @@ void row_reduce_general_dispatch(
     set_array_buffer(compute_encoder, out, 1);
     compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
     compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
+    compute_encoder->setBytes(&non_row_reductions, sizeof(size_t), 4);
     compute_encoder->setBytes(
-        new_shape.data(), new_shape.size() * sizeof(int), 4);
+        new_shape.data(), new_shape.size() * sizeof(int), 5);
     compute_encoder->setBytes(
-        new_strides.data(), new_strides.size() * sizeof(size_t), 5);
-    compute_encoder->setBytes(&ndim, sizeof(int), 6);
+        new_strides.data(), new_strides.size() * sizeof(size_t), 6);
+    compute_encoder->setBytes(&ndim, sizeof(int), 7);
 
     // Each thread group is responsible for 1 output
-    thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
+    int n_reads = REDUCE_N_READS;
+    size_t thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
     thread_group_size =
         std::min((reduction_size + n_reads - 1) / n_reads, thread_group_size);
 
     // Align thread group size with simd_size
+    uint simd_size = kernel->threadExecutionWidth();
     thread_group_size =
         (thread_group_size + simd_size - 1) / simd_size * simd_size;
     assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup());
 
     // Launch enough thread groups for each output
-    n_threads = thread_group_size;
+    size_t n_threads = thread_group_size;
     grid_dims = MTL::Size(n_threads, out.size(), 1);
     group_dims = MTL::Size(thread_group_size, 1, 1);
 
@@ -417,11 +454,12 @@ void strided_reduce_general_dispatch(
     set_array_buffer(compute_encoder, out, 1);
     compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
     compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
+    compute_encoder->setBytes(&reduction_size, sizeof(size_t), 4);
     compute_encoder->setBytes(
-        new_shape.data(), new_shape.size() * sizeof(int), 4);
+        new_shape.data(), new_shape.size() * sizeof(int), 5);
     compute_encoder->setBytes(
-        new_strides.data(), new_strides.size() * sizeof(size_t), 5);
-    compute_encoder->setBytes(&ndim, sizeof(int), 6);
+        new_strides.data(), new_strides.size() * sizeof(size_t), 6);
+    compute_encoder->setBytes(&ndim, sizeof(int), 7);
 
     // Each thread group is responsible for 1 output
     size_t n_reads = REDUCE_N_READS;