diff --git a/mlx/backend/common/reduce.h b/mlx/backend/common/reduce.h
index 0e9088579..1815d65bf 100644
--- a/mlx/backend/common/reduce.h
+++ b/mlx/backend/common/reduce.h
@@ -49,7 +49,7 @@ struct ReductionPlan {
   ReductionPlan(ReductionOpType type_) : type(type_) {}
 };
 
-ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes);
+ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
 
 // Helper for the ndimensional strided loop
 // Should this be in utils?
diff --git a/mlx/backend/common/reduce_utils.cpp b/mlx/backend/common/reduce_utils.cpp
index 47b0f6c32..b15bc9e18 100644
--- a/mlx/backend/common/reduce_utils.cpp
+++ b/mlx/backend/common/reduce_utils.cpp
@@ -19,7 +19,7 @@ std::pair<std::vector<int>, std::vector<size_t>> shapes_without_reduction_axes(
   return std::make_pair(shape, strides);
 }
 
-ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
+ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
   // The data is all there and we are reducing over everything
   if (x.size() == x.data_size() && axes.size() == x.ndim() &&
       x.flags().contiguous) {
@@ -41,6 +41,14 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
     }
   }
 
+  // Remove singleton axes from the plan
+  for (int i = shape.size() - 1; i >= 0; i--) {
+    if (shape[i] == 1) {
+      shape.erase(shape.begin() + i);
+      strides.erase(strides.begin() + i);
+    }
+  }
+
   if (strides.back() == 1) {
     return ReductionPlan(ContiguousReduce, shape, strides);
   } else if (strides.back() > 1) {
@@ -63,10 +71,14 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
   // have a contiguous reduction.
   std::vector<std::pair<int, size_t>> reductions;
   for (auto a : axes) {
-    reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
+    if (x.shape(a) > 1) {
+      reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
+    }
   }
   std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) {
-    return a.second > b.second;
+    bool a_is_zero = a.second == 0;
+    bool b_is_zero = b.second == 0;
+    return (a_is_zero != b_is_zero) ? a.second < b.second : a.second > b.second;
   });
   // Extract the two smallest and try to merge them in case the contiguous
   // reduction can be bigger than just the last axis.
@@ -98,16 +110,33 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
   // strides.back() are contiguous.
   if (strides.back() > 1) {
     int size = 1;
+    bool have_expand = false;
     for (int i = x.ndim() - 1; i >= 0; i--) {
       if (axes.back() == i) {
         continue;
       }
-      if (x.strides()[i] != size) {
+
+      size_t stride_i = x.strides()[i];
+      int shape_i = x.shape(i);
+      if (stride_i == 0) {
+        if (shape_i == 1) {
+          continue;
+        }
+
+        have_expand = true;
         break;
       }
-      size *= x.shape(i);
+
+      if (stride_i != size && shape_i != 1) {
+        break;
+      }
+      size *= shape_i;
     }
-    if (size >= strides.back()) {
+    // In the case of an expanded dimension we are being conservative and
+    // require the smallest reduction stride to be smaller than the maximum row
+    // contiguous size. The reason is that we can't easily know if the reduced
+    // axis is before or after an expanded dimension.
+    if (size > strides.back() || (size == strides.back() && !have_expand)) {
       return ReductionPlan(GeneralStridedReduce, shape, strides);
     }
   }
diff --git a/mlx/backend/common/utils.h b/mlx/backend/common/utils.h
index 14252f378..f8d2c9117 100644
--- a/mlx/backend/common/utils.h
+++ b/mlx/backend/common/utils.h
@@ -104,6 +104,33 @@ inline auto collapse_contiguous_dims(Arrays&&...
xs) { std::vector{std::forward(xs)...}); } +// The single array version of the above. +inline std::tuple, std::vector> +collapse_contiguous_dims( + const std::vector& shape, + const std::vector& strides) { + std::vector collapsed_shape; + std::vector collapsed_strides; + + if (shape.size() > 0) { + collapsed_shape.push_back(shape[0]); + collapsed_strides.push_back(strides[0]); + for (int i = 1; i < shape.size(); i++) { + if (strides[i] * shape[i] != collapsed_strides.back() || + collapsed_shape.back() * static_cast(shape[i]) > + std::numeric_limits::max()) { + collapsed_shape.push_back(shape[i]); + collapsed_strides.push_back(strides[i]); + } else { + collapsed_shape.back() *= shape[i]; + collapsed_strides.back() = strides[i]; + } + } + } + + return std::make_tuple(collapsed_shape, collapsed_strides); +} + template inline auto check_contiguity( const std::vector& shape, diff --git a/mlx/backend/metal/kernels/atomic.h b/mlx/backend/metal/kernels/atomic.h index 5f526de04..93952c2cf 100644 --- a/mlx/backend/metal/kernels/atomic.h +++ b/mlx/backend/metal/kernels/atomic.h @@ -37,13 +37,13 @@ struct mlx_atomic>> { template , bool> = true> METAL_FUNC T -mlx_atomic_load_explicit(device mlx_atomic* object, uint offset) { +mlx_atomic_load_explicit(device mlx_atomic* object, size_t offset) { return atomic_load_explicit(&(object[offset].val), memory_order_relaxed); } template , bool> = true> METAL_FUNC void -mlx_atomic_store_explicit(device mlx_atomic* object, T val, uint offset) { +mlx_atomic_store_explicit(device mlx_atomic* object, T val, size_t offset) { atomic_store_explicit(&(object[offset].val), val, memory_order_relaxed); } @@ -51,13 +51,15 @@ template , bool> = true> METAL_FUNC void mlx_atomic_fetch_and_explicit( device mlx_atomic* object, T val, - uint offset) { + size_t offset) { atomic_fetch_and_explicit(&(object[offset].val), val, memory_order_relaxed); } template , bool> = true> -METAL_FUNC void -mlx_atomic_fetch_or_explicit(device mlx_atomic* object, T val, uint offset) { +METAL_FUNC void mlx_atomic_fetch_or_explicit( + device mlx_atomic* object, + T val, + size_t offset) { atomic_fetch_or_explicit(&(object[offset].val), val, memory_order_relaxed); } @@ -65,7 +67,7 @@ template , bool> = true> METAL_FUNC void mlx_atomic_fetch_min_explicit( device mlx_atomic* object, T val, - uint offset) { + size_t offset) { atomic_fetch_min_explicit(&(object[offset].val), val, memory_order_relaxed); } @@ -73,7 +75,7 @@ template , bool> = true> METAL_FUNC void mlx_atomic_fetch_max_explicit( device mlx_atomic* object, T val, - uint offset) { + size_t offset) { atomic_fetch_max_explicit(&(object[offset].val), val, memory_order_relaxed); } @@ -81,7 +83,7 @@ template , bool> = true> METAL_FUNC void mlx_atomic_fetch_add_explicit( device mlx_atomic* object, T val, - uint offset) { + size_t offset) { atomic_fetch_add_explicit(&(object[offset].val), val, memory_order_relaxed); } @@ -89,7 +91,7 @@ template , bool> = true> METAL_FUNC void mlx_atomic_fetch_mul_explicit( device mlx_atomic* object, T val, - uint offset) { + size_t offset) { T expected = mlx_atomic_load_explicit(object, offset); while (!mlx_atomic_compare_exchange_weak_explicit( object, &expected, val * expected, offset)) { @@ -101,7 +103,7 @@ METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit( device mlx_atomic* object, thread T* expected, T val, - uint offset) { + size_t offset) { return atomic_compare_exchange_weak_explicit( &(object[offset].val), expected, @@ -115,7 +117,7 @@ template <> METAL_FUNC void mlx_atomic_fetch_min_explicit( device 
mlx_atomic* object, float val, - uint offset) { + size_t offset) { float expected = mlx_atomic_load_explicit(object, offset); while (val < expected) { if (mlx_atomic_compare_exchange_weak_explicit( @@ -130,7 +132,7 @@ template <> METAL_FUNC void mlx_atomic_fetch_max_explicit( device mlx_atomic* object, float val, - uint offset) { + size_t offset) { float expected = mlx_atomic_load_explicit(object, offset); while (val > expected) { if (mlx_atomic_compare_exchange_weak_explicit( @@ -157,7 +159,7 @@ union uint_or_packed { template struct mlx_atomic_update_helper { - uint operator()(uint_or_packed init, T update, uint elem_offset) { + uint operator()(uint_or_packed init, T update, size_t elem_offset) { Op op; init.val[elem_offset] = op(update, init.val[elem_offset]); return init.bits; @@ -168,9 +170,9 @@ template METAL_FUNC void mlx_atomic_update_and_store( device mlx_atomic* object, T update, - uint offset) { - uint pack_offset = offset / packing_size; - uint elem_offset = offset % packing_size; + size_t offset) { + size_t pack_offset = offset / packing_size; + size_t elem_offset = offset % packing_size; mlx_atomic_update_helper helper; uint_or_packed expected; @@ -251,9 +253,9 @@ struct __Min { template , bool> = true> METAL_FUNC T -mlx_atomic_load_explicit(device mlx_atomic* object, uint offset) { - uint pack_offset = offset / sizeof(T); - uint elem_offset = offset % sizeof(T); +mlx_atomic_load_explicit(device mlx_atomic* object, size_t offset) { + size_t pack_offset = offset / sizeof(T); + size_t elem_offset = offset % sizeof(T); uint_or_packed packed_val; packed_val.bits = atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed); @@ -262,7 +264,7 @@ mlx_atomic_load_explicit(device mlx_atomic* object, uint offset) { template , bool> = true> METAL_FUNC void -mlx_atomic_store_explicit(device mlx_atomic* object, T val, uint offset) { +mlx_atomic_store_explicit(device mlx_atomic* object, T val, size_t offset) { mlx_atomic_update_and_store>(object, val, offset); } @@ -270,9 +272,9 @@ template , bool> = true> METAL_FUNC void mlx_atomic_fetch_and_explicit( device mlx_atomic* object, T val, - uint offset) { - uint pack_offset = offset / packing_size; - uint elem_offset = offset % packing_size; + size_t offset) { + size_t pack_offset = offset / packing_size; + size_t elem_offset = offset % packing_size; uint_or_packed identity; identity.bits = __UINT32_MAX__; identity.val[elem_offset] = val; @@ -282,10 +284,12 @@ METAL_FUNC void mlx_atomic_fetch_and_explicit( } template , bool> = true> -METAL_FUNC void -mlx_atomic_fetch_or_explicit(device mlx_atomic* object, T val, uint offset) { - uint pack_offset = offset / packing_size; - uint elem_offset = offset % packing_size; +METAL_FUNC void mlx_atomic_fetch_or_explicit( + device mlx_atomic* object, + T val, + size_t offset) { + size_t pack_offset = offset / packing_size; + size_t elem_offset = offset % packing_size; uint_or_packed identity; identity.bits = 0; identity.val[elem_offset] = val; @@ -298,7 +302,7 @@ template , bool> = true> METAL_FUNC void mlx_atomic_fetch_min_explicit( device mlx_atomic* object, T val, - uint offset) { + size_t offset) { mlx_atomic_update_and_store>(object, val, offset); } @@ -306,7 +310,7 @@ template , bool> = true> METAL_FUNC void mlx_atomic_fetch_max_explicit( device mlx_atomic* object, T val, - uint offset) { + size_t offset) { mlx_atomic_update_and_store>(object, val, offset); } @@ -314,7 +318,7 @@ template , bool> = true> METAL_FUNC void mlx_atomic_fetch_add_explicit( device mlx_atomic* object, T val, - 
uint offset) { + size_t offset) { mlx_atomic_update_and_store>(object, val, offset); } @@ -322,7 +326,7 @@ template , bool> = true> METAL_FUNC void mlx_atomic_fetch_mul_explicit( device mlx_atomic* object, T val, - uint offset) { + size_t offset) { mlx_atomic_update_and_store>(object, val, offset); } @@ -331,7 +335,7 @@ METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit( device mlx_atomic* object, thread uint* expected, uint val, - uint offset) { + size_t offset) { return atomic_compare_exchange_weak_explicit( &(object[offset].val), expected, diff --git a/mlx/backend/metal/kernels/complex.h b/mlx/backend/metal/kernels/complex.h index df69bfe9f..fe8ec5c0f 100644 --- a/mlx/backend/metal/kernels/complex.h +++ b/mlx/backend/metal/kernels/complex.h @@ -23,6 +23,8 @@ struct complex64_t { // Constructors constexpr complex64_t(float real, float imag) : real(real), imag(imag) {}; + constexpr complex64_t() : real(0), imag(0) {}; + constexpr complex64_t() threadgroup : real(0), imag(0) {}; // Conversions to complex64_t template < diff --git a/mlx/backend/metal/kernels/defines.h b/mlx/backend/metal/kernels/defines.h index 3e98252b5..c369adb7e 100644 --- a/mlx/backend/metal/kernels/defines.h +++ b/mlx/backend/metal/kernels/defines.h @@ -9,7 +9,8 @@ #endif static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4; -static MTL_CONST constexpr int REDUCE_N_READS = 16; +static MTL_CONST constexpr int REDUCE_N_READS = 4; +static MTL_CONST constexpr int REDUCE_N_WRITES = 4; static MTL_CONST constexpr int SOFTMAX_N_READS = 4; static MTL_CONST constexpr int RMS_N_READS = 4; static MTL_CONST constexpr int RMS_LOOPED_LIMIT = 4096; diff --git a/mlx/backend/metal/kernels/reduce.metal b/mlx/backend/metal/kernels/reduce.metal index 0e4650015..3d7e92c52 100644 --- a/mlx/backend/metal/kernels/reduce.metal +++ b/mlx/backend/metal/kernels/reduce.metal @@ -28,7 +28,8 @@ #define instantiate_reduce_helper_64b(inst_f, name, op) \ inst_f(name, int64, int64_t, op) \ - inst_f(name, uint64, uint64_t, op) + inst_f(name, uint64, uint64_t, op) \ + inst_f(name, complex64, complex64_t, op) #define instantiate_reduce_helper_types(inst_f, name, op) \ instantiate_reduce_helper_floats(inst_f, name, op) \ @@ -97,40 +98,24 @@ instantiate_init_reduce(andbool_, bool, And) instantiate_init_reduce(orbool_, bool, Or) #define instantiate_all_reduce(name, itype, otype, op) \ - template [[host_name("all_reduce_" #name)]] [[kernel]] void \ - all_reduce( \ + template [[host_name("all_reduce_" #name)]] \ + [[kernel]] void all_reduce( \ const device itype* in [[buffer(0)]], \ - device mlx_atomic* out [[buffer(1)]], \ - const device size_t& in_size [[buffer(2)]], \ - uint gid [[thread_position_in_grid]], \ - uint lid [[thread_position_in_threadgroup]], \ - uint grid_size [[threads_per_grid]], \ + device otype* out [[buffer(1)]], \ + const constant size_t& in_size [[buffer(2)]], \ + const constant size_t& row_size [[buffer(3)]], \ + uint3 gid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint3 lsize [[threads_per_threadgroup]], \ uint simd_per_group [[simdgroups_per_threadgroup]], \ uint simd_lane_id [[thread_index_in_simdgroup]], \ uint simd_group_id [[simdgroup_index_in_threadgroup]]); -#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \ - template [[host_name("allNoAtomics_reduce_" #name)]] [[kernel]] void \ - all_reduce_no_atomics( \ - const device itype* in [[buffer(0)]], \ - device otype* out [[buffer(1)]], \ - const device size_t& in_size [[buffer(2)]], \ - uint gid 
[[thread_position_in_grid]], \ - uint lid [[thread_position_in_threadgroup]], \ - uint grid_size [[threads_per_grid]], \ - uint simd_per_group [[simdgroups_per_threadgroup]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]], \ - uint thread_group_id [[threadgroup_position_in_grid]]); - #define instantiate_same_all_reduce_helper(name, tname, type, op) \ instantiate_all_reduce(name##tname, type, type, op) -#define instantiate_same_all_reduce_na_helper(name, tname, type, op) \ - instantiate_all_reduce_no_atomics(name##tname, type, type, op) - instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_types) -instantiate_reduce_ops(instantiate_same_all_reduce_na_helper, instantiate_reduce_helper_64b) +instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_64b) instantiate_reduce_from_types(instantiate_all_reduce, and, bool, And) instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or) @@ -138,153 +123,143 @@ instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or) // special case bool with larger output type instantiate_all_reduce(sumbool_, bool, uint32_t, Sum) -#define instantiate_col_reduce_general(name, itype, otype, op) \ - template [[host_name("colGeneral_reduce_" #name)]] [[kernel]] void \ - col_reduce_general( \ - const device itype* in [[buffer(0)]], \ - device mlx_atomic* out [[buffer(1)]], \ - const constant size_t& reduction_size [[buffer(2)]], \ - const constant size_t& reduction_stride [[buffer(3)]], \ - const constant size_t& out_size [[buffer(4)]], \ - const constant int* shape [[buffer(5)]], \ - const constant size_t* strides [[buffer(6)]], \ - const constant int& ndim [[buffer(7)]], \ - threadgroup otype* local_data [[threadgroup(0)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]], \ - uint3 lsize [[threads_per_threadgroup]]); +#define instantiate_col_reduce_small(name, itype, otype, op, dim) \ + template [[host_name("colSmall" #dim "_reduce_" #name)]] \ + [[kernel]] void col_reduce_small( \ + const device itype* in [[buffer(0)]], \ + device otype* out [[buffer(1)]], \ + const constant size_t& reduction_size [[buffer(2)]], \ + const constant size_t& reduction_stride [[buffer(3)]], \ + const constant int* shape [[buffer(4)]], \ + const constant size_t* strides [[buffer(5)]], \ + const constant int& ndim [[buffer(6)]], \ + const constant int* reduce_shape [[buffer(7)]], \ + const constant size_t* reduce_strides [[buffer(8)]], \ + const constant int& reduce_ndim [[buffer(9)]], \ + const constant size_t& non_col_reductions [[buffer(10)]], \ + uint3 gid [[threadgroup_position_in_grid]], \ + uint3 gsize [[threadgroups_per_grid]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]], \ + uint3 tid [[thread_position_in_grid]], \ + uint3 tsize [[threads_per_grid]]); -#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \ - template \ - [[host_name("colGeneralNoAtomics_reduce_" #name)]] [[kernel]] void \ - col_reduce_general_no_atomics( \ - const device itype* in [[buffer(0)]], \ - device otype* out [[buffer(1)]], \ - const constant size_t& reduction_size [[buffer(2)]], \ - const constant size_t& reduction_stride [[buffer(3)]], \ - const constant size_t& out_size [[buffer(4)]], \ - const constant int* shape [[buffer(5)]], \ - const constant size_t* strides [[buffer(6)]], \ - const constant int& ndim [[buffer(7)]], \ - 
threadgroup otype* local_data [[threadgroup(0)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]], \ - uint3 gid [[thread_position_in_grid]], \ - uint3 lsize [[threads_per_threadgroup]], \ - uint3 gsize [[threads_per_grid]]); +#define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \ + template [[host_name("colLooped" #dim "_" #bm "_" #bn "_reduce_" #name)]] \ + [[kernel]] void col_reduce_looped( \ + const device itype* in [[buffer(0)]], \ + device otype* out [[buffer(1)]], \ + const constant size_t& reduction_size [[buffer(2)]], \ + const constant size_t& reduction_stride [[buffer(3)]], \ + const constant int* shape [[buffer(4)]], \ + const constant size_t* strides [[buffer(5)]], \ + const constant int& ndim [[buffer(6)]], \ + const constant int* reduce_shape [[buffer(7)]], \ + const constant size_t* reduce_strides [[buffer(8)]], \ + const constant int& reduce_ndim [[buffer(9)]], \ + const constant size_t& non_col_reductions [[buffer(10)]], \ + uint3 gid [[threadgroup_position_in_grid]], \ + uint3 gsize [[threadgroups_per_grid]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]]); -#define instantiate_col_reduce_small(name, itype, otype, op) \ - template [[host_name("colSmall_reduce_" #name)]] [[kernel]] void \ - col_reduce_small( \ - const device itype* in [[buffer(0)]], \ - device otype* out [[buffer(1)]], \ - const constant size_t& reduction_size [[buffer(2)]], \ - const constant size_t& reduction_stride [[buffer(3)]], \ - const constant size_t& out_size [[buffer(4)]], \ - const constant int* shape [[buffer(5)]], \ - const constant size_t* strides [[buffer(6)]], \ - const constant int& ndim [[buffer(7)]], \ - const constant size_t& non_col_reductions [[buffer(8)]], \ - const constant int* non_col_shapes [[buffer(9)]], \ - const constant size_t* non_col_strides [[buffer(10)]], \ - const constant int& non_col_ndim [[buffer(11)]], \ - uint tid [[thread_position_in_grid]]); +#define instantiate_col_reduce_looped(name, itype, otype, op, dim) \ + instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 8, 128) \ + instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32) -#define instantiate_same_col_reduce_helper(name, tname, type, op) \ - instantiate_col_reduce_small(name ##tname, type, type, op) \ - instantiate_col_reduce_general(name ##tname, type, type, op) +#define instantiate_col_reduce_general(name, itype, otype, op) \ + instantiate_col_reduce_small(name, itype, otype, op, 0) \ + instantiate_col_reduce_small(name, itype, otype, op, 1) \ + instantiate_col_reduce_small(name, itype, otype, op, 2) \ + instantiate_col_reduce_small(name, itype, otype, op, 3) \ + instantiate_col_reduce_small(name, itype, otype, op, 4) \ + instantiate_col_reduce_looped(name, itype, otype, op, 0) \ + instantiate_col_reduce_looped(name, itype, otype, op, 1) \ + instantiate_col_reduce_looped(name, itype, otype, op, 2) \ + instantiate_col_reduce_looped(name, itype, otype, op, 3) \ + instantiate_col_reduce_looped(name, itype, otype, op, 4) -#define instantiate_same_col_reduce_na_helper(name, tname, type, op) \ - instantiate_col_reduce_small(name ##tname, type, type, op) \ - instantiate_col_reduce_general_no_atomics(name ##tname, type, type, op) +#define instantiate_same_col_reduce_helper(name, tname, type, op) \ + instantiate_col_reduce_general(name##tname, type, type, op) instantiate_reduce_ops(instantiate_same_col_reduce_helper, 
instantiate_reduce_helper_types) -instantiate_reduce_ops(instantiate_same_col_reduce_na_helper, instantiate_reduce_helper_64b) +instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_64b) instantiate_col_reduce_general(sumbool_, bool, uint32_t, Sum) instantiate_reduce_from_types(instantiate_col_reduce_general, and, bool, And) instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or) -instantiate_col_reduce_small(sumbool_, bool, uint32_t, Sum) -instantiate_reduce_from_types(instantiate_col_reduce_small, and, bool, And) -instantiate_reduce_from_types(instantiate_col_reduce_small, or, bool, Or) - -#define instantiate_row_reduce_small(name, itype, otype, op) \ - template [[host_name("rowGeneralSmall_reduce_" #name)]] [[kernel]] void \ - row_reduce_general_small( \ - const device itype* in [[buffer(0)]], \ - device otype* out [[buffer(1)]], \ - const constant size_t& reduction_size [[buffer(2)]], \ - const constant size_t& out_size [[buffer(3)]], \ - const constant size_t& non_row_reductions [[buffer(4)]], \ - const constant int* shape [[buffer(5)]], \ - const constant size_t* strides [[buffer(6)]], \ - const constant int& ndim [[buffer(7)]], \ - uint lid [[thread_position_in_grid]]); \ - template [[host_name("rowGeneralMed_reduce_" #name)]] [[kernel]] void \ - row_reduce_general_med( \ - const device itype* in [[buffer(0)]], \ - device otype* out [[buffer(1)]], \ - const constant size_t& reduction_size [[buffer(2)]], \ - const constant size_t& out_size [[buffer(3)]], \ - const constant size_t& non_row_reductions [[buffer(4)]], \ - const constant int* shape [[buffer(5)]], \ - const constant size_t* strides [[buffer(6)]], \ - const constant int& ndim [[buffer(7)]], \ - uint tid [[threadgroup_position_in_grid]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_per_group [[dispatch_simdgroups_per_threadgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]]); +#define instantiate_row_reduce_small(name, itype, otype, op, dim) \ + template [[host_name("rowSmall" #dim "_reduce_" #name)]] [[kernel]] void \ + row_reduce_small( \ + const device itype* in [[buffer(0)]], \ + device otype* out [[buffer(1)]], \ + const constant size_t& row_size [[buffer(2)]], \ + const constant size_t& non_row_reductions [[buffer(3)]], \ + const constant int* shape [[buffer(4)]], \ + const constant size_t* strides [[buffer(5)]], \ + const constant int& ndim [[buffer(6)]], \ + const constant int* reduce_shape [[buffer(7)]], \ + const constant size_t* reduce_strides [[buffer(8)]], \ + const constant int& reduce_ndim [[buffer(9)]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint3 gid [[threadgroup_position_in_grid]], \ + uint3 gsize [[threadgroups_per_grid]], \ + uint3 tid [[thread_position_in_grid]], \ + uint3 tsize [[threads_per_grid]]); +#define instantiate_row_reduce_looped(name, itype, otype, op, dim) \ + template \ + [[host_name("rowLooped" #dim "_reduce_" #name)]] [[kernel]] void \ + row_reduce_looped( \ + const device itype* in [[buffer(0)]], \ + device otype* out [[buffer(1)]], \ + const constant size_t& row_size [[buffer(2)]], \ + const constant size_t& non_row_reductions [[buffer(3)]], \ + const constant int* shape [[buffer(4)]], \ + const constant size_t* strides [[buffer(5)]], \ + const constant int& ndim [[buffer(6)]], \ + const constant int* reduce_shape [[buffer(7)]], \ + const constant size_t* reduce_strides [[buffer(8)]], \ + const constant int& reduce_ndim [[buffer(9)]], \ + uint3 gid [[threadgroup_position_in_grid]], 
\ + uint3 gsize [[threadgroups_per_grid]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint3 lsize [[threads_per_threadgroup]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_per_group [[simdgroups_per_threadgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]]); #define instantiate_row_reduce_general(name, itype, otype, op) \ - instantiate_row_reduce_small(name, itype, otype, op) \ + instantiate_row_reduce_small(name, itype, otype, op, 0) \ + instantiate_row_reduce_small(name, itype, otype, op, 1) \ + instantiate_row_reduce_small(name, itype, otype, op, 2) \ + instantiate_row_reduce_small(name, itype, otype, op, 3) \ + instantiate_row_reduce_small(name, itype, otype, op, 4) \ + instantiate_row_reduce_looped(name, itype, otype, op, 0) \ + instantiate_row_reduce_looped(name, itype, otype, op, 1) \ + instantiate_row_reduce_looped(name, itype, otype, op, 2) \ + instantiate_row_reduce_looped(name, itype, otype, op, 3) \ + instantiate_row_reduce_looped(name, itype, otype, op, 4) \ template \ - [[host_name("rowGeneral_reduce_" #name)]] [[kernel]] void \ - row_reduce_general( \ + [[host_name("rowSimple_reduce_" #name)]] [[kernel]] void \ + row_reduce_simple( \ const device itype* in [[buffer(0)]], \ - device mlx_atomic* out [[buffer(1)]], \ + device otype* out [[buffer(1)]], \ const constant size_t& reduction_size [[buffer(2)]], \ const constant size_t& out_size [[buffer(3)]], \ - const constant size_t& non_row_reductions [[buffer(4)]], \ - const constant int* shape [[buffer(5)]], \ - const constant size_t* strides [[buffer(6)]], \ - const constant int& ndim [[buffer(7)]], \ + uint3 gid [[threadgroup_position_in_grid]], \ + uint3 gsize [[threadgroups_per_grid]], \ uint3 lid [[thread_position_in_threadgroup]], \ uint3 lsize [[threads_per_threadgroup]], \ - uint3 tid [[threadgroup_position_in_grid]], \ uint simd_lane_id [[thread_index_in_simdgroup]], \ uint simd_per_group [[simdgroups_per_threadgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]]); - -#define instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \ - instantiate_row_reduce_small(name, itype, otype, op) \ - template \ - [[host_name("rowGeneralNoAtomics_reduce_" #name)]] [[kernel]] void \ - row_reduce_general_no_atomics( \ - const device itype* in [[buffer(0)]], \ - device otype* out [[buffer(1)]], \ - const constant size_t& reduction_size [[buffer(2)]], \ - const constant size_t& out_size [[buffer(3)]], \ - const constant size_t& non_row_reductions [[buffer(4)]], \ - const constant int* shape [[buffer(5)]], \ - const constant size_t* strides [[buffer(6)]], \ - const constant int& ndim [[buffer(7)]], \ - uint3 lid [[thread_position_in_threadgroup]], \ - uint3 lsize [[threads_per_threadgroup]], \ - uint3 gsize [[threads_per_grid]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_per_group [[simdgroups_per_threadgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]]); + uint simd_group_id [[simdgroup_index_in_threadgroup]]); \ #define instantiate_same_row_reduce_helper(name, tname, type, op) \ instantiate_row_reduce_general(name##tname, type, type, op) -#define instantiate_same_row_reduce_na_helper(name, tname, type, op) \ - instantiate_row_reduce_general_no_atomics(name##tname, type, type, op) - instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_types) -instantiate_reduce_ops(instantiate_same_row_reduce_na_helper, instantiate_reduce_helper_64b) 
+instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_64b) instantiate_reduce_from_types(instantiate_row_reduce_general, and, bool, And) instantiate_reduce_from_types(instantiate_row_reduce_general, or, bool, Or) diff --git a/mlx/backend/metal/kernels/reduction/ops.h b/mlx/backend/metal/kernels/reduction/ops.h index 48a0c87e1..68ed11986 100644 --- a/mlx/backend/metal/kernels/reduction/ops.h +++ b/mlx/backend/metal/kernels/reduction/ops.h @@ -5,6 +5,20 @@ #include #include +#define DEFINE_SIMD_REDUCE() \ + template = true> \ + T simd_reduce(T val) { \ + return simd_reduce_impl(val); \ + } \ + \ + template = true> \ + T simd_reduce(T val) { \ + for (short i = simd_size / 2; i > 0; i /= 2) { \ + val = operator()(val, simd_shuffle_down(val, i)); \ + } \ + return val; \ + } + static constant constexpr const uint8_t simd_size = 32; union bool4_or_uint { @@ -14,14 +28,16 @@ union bool4_or_uint { struct None { template - void atomic_update(device mlx_atomic* out, T val, uint offset = 0) { + void atomic_update(device mlx_atomic* out, T val, size_t offset = 0) { mlx_atomic_store_explicit(out, val, offset); } }; template struct And { - bool simd_reduce(bool val) { + DEFINE_SIMD_REDUCE() + + bool simd_reduce_impl(bool val) { return simd_all(val); } @@ -31,7 +47,7 @@ struct And { device mlx_atomic* out, bool val, int elem_idx, - int offset = 0) { + size_t offset = 0) { if (!val) { bool4_or_uint update; update.b = {true, true, true, true}; @@ -40,7 +56,8 @@ struct And { } } - void atomic_update(device mlx_atomic* out, bool val, uint offset = 0) { + void + atomic_update(device mlx_atomic* out, bool val, size_t offset = 0) { if (!val) { mlx_atomic_store_explicit(out, val, offset); } @@ -59,7 +76,9 @@ struct And { template struct Or { - bool simd_reduce(bool val) { + DEFINE_SIMD_REDUCE() + + bool simd_reduce_impl(bool val) { return simd_any(val); } @@ -68,8 +87,8 @@ struct Or { void atomic_update( device mlx_atomic* out, bool val, - uint elem_idx, - uint offset = 0) { + int elem_idx, + size_t offset = 0) { if (val) { bool4_or_uint update; update.b = {false, false, false, false}; @@ -78,7 +97,8 @@ struct Or { } } - void atomic_update(device mlx_atomic* out, bool val, uint offset = 0) { + void + atomic_update(device mlx_atomic* out, bool val, size_t offset = 0) { if (val) { mlx_atomic_store_explicit(out, val, offset); } @@ -97,15 +117,17 @@ struct Or { template struct Sum { + DEFINE_SIMD_REDUCE() + template - T simd_reduce(T val) { + T simd_reduce_impl(T val) { return simd_sum(val); } static constexpr constant U init = U(0); template - void atomic_update(device mlx_atomic* out, T val, uint offset = 0) { + void atomic_update(device mlx_atomic* out, T val, size_t offset = 0) { mlx_atomic_fetch_add_explicit(out, val, offset); } @@ -117,15 +139,17 @@ struct Sum { template struct Prod { + DEFINE_SIMD_REDUCE() + template - T simd_reduce(T val) { + T simd_reduce_impl(T val) { return simd_product(val); } static constexpr constant U init = U(1); template - void atomic_update(device mlx_atomic* out, T val, uint offset = 0) { + void atomic_update(device mlx_atomic* out, T val, size_t offset = 0) { mlx_atomic_fetch_mul_explicit(out, val, offset); } @@ -137,15 +161,17 @@ struct Prod { template struct Min { + DEFINE_SIMD_REDUCE() + template - T simd_reduce(T val) { + T simd_reduce_impl(T val) { return simd_min(val); } static constexpr constant U init = Limits::max; template - void atomic_update(device mlx_atomic* out, T val, uint offset = 0) { + void atomic_update(device mlx_atomic* out, T 
val, size_t offset = 0) { mlx_atomic_fetch_min_explicit(out, val, offset); } @@ -157,15 +183,17 @@ struct Min { template struct Max { + DEFINE_SIMD_REDUCE() + template - T simd_reduce(T val) { + T simd_reduce_impl(T val) { return simd_max(val); } static constexpr constant U init = Limits::min; template - void atomic_update(device mlx_atomic* out, T val, uint offset = 0) { + void atomic_update(device mlx_atomic* out, T val, size_t offset = 0) { mlx_atomic_fetch_max_explicit(out, val, offset); } diff --git a/mlx/backend/metal/kernels/reduction/reduce_all.h b/mlx/backend/metal/kernels/reduction/reduce_all.h index e5928b615..381d5e20b 100644 --- a/mlx/backend/metal/kernels/reduction/reduce_all.h +++ b/mlx/backend/metal/kernels/reduction/reduce_all.h @@ -1,135 +1,61 @@ // Copyright © 2023-2024 Apple Inc. -/////////////////////////////////////////////////////////////////////////////// -// All reduce helper -/////////////////////////////////////////////////////////////////////////////// - -template -METAL_FUNC U per_thread_all_reduce( - const device T* in, - const device size_t& in_size, - uint gid, - uint grid_size) { - Op op; - U total_val = Op::init; - - if (gid * N_READS < in_size) { - in += gid * N_READS; - - int r = 0; - for (; r < (int)ceildiv(in_size, grid_size * N_READS) - 1; r++) { - U vals[N_READS] = {op.init}; - - for (int i = 0; i < N_READS; i++) { - vals[i] = static_cast(in[i]); - } - for (int i = 0; i < N_READS; i++) { - total_val = op(vals[i], total_val); - } - - in += grid_size * N_READS; - } - - // Separate case for the last set as we close the reduction size - size_t curr_idx = (gid + r * (size_t)grid_size) * N_READS; - if (curr_idx < in_size) { - int max_reads = in_size - curr_idx; - T vals[N_READS]; - - for (int i = 0, idx = 0; i < N_READS; i++, idx++) { - idx = idx < max_reads ? idx : max_reads - 1; - vals[i] = in[idx]; - } - for (int i = 0; i < N_READS; i++) { - U val = i < max_reads ? vals[i] : Op::init; - total_val = op(static_cast(val), total_val); - } - } - } - - return total_val; -} - -/////////////////////////////////////////////////////////////////////////////// -// All reduce kernel -/////////////////////////////////////////////////////////////////////////////// - -// NB: This kernel assumes threads_per_threadgroup is at most -// 1024. This way with a simd_size of 32, we are guaranteed to -// complete the reduction in two steps of simd-level reductions. template [[kernel]] void all_reduce( const device T* in [[buffer(0)]], - device mlx_atomic* out [[buffer(1)]], - const device size_t& in_size [[buffer(2)]], - uint gid [[thread_position_in_grid]], - uint lid [[thread_position_in_threadgroup]], - uint grid_size [[threads_per_grid]], + device U* out [[buffer(1)]], + const constant size_t& in_size [[buffer(2)]], + const constant size_t& row_size [[buffer(3)]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 lsize [[threads_per_threadgroup]], uint simd_per_group [[simdgroups_per_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { Op op; - threadgroup U local_vals[simd_size]; + threadgroup U shared_vals[simd_size]; - U total_val = - per_thread_all_reduce(in, in_size, gid, grid_size); + U total = Op::init; + int64_t start_idx = gid.y * row_size; + int64_t actual_row = + (start_idx + row_size <= in_size) ? 
row_size : in_size - start_idx; + int64_t blocks = actual_row / (lsize.x * N_READS); + int extra = actual_row - blocks * (lsize.x * N_READS); + extra -= lid.x * N_READS; + start_idx += lid.x * N_READS; + in += start_idx; + + if (extra >= N_READS) { + blocks++; + extra = 0; + } + + for (int64_t b = 0; b < blocks; b++) { + for (int i = 0; i < N_READS; i++) { + total = op(static_cast(in[i]), total); + } + in += lsize.x * N_READS; + } + if (extra > 0) { + for (int i = 0; i < extra; i++) { + total = op(static_cast(in[i]), total); + } + } // Reduction within simd group - total_val = op.simd_reduce(total_val); - if (simd_lane_id == 0) { - local_vals[simd_group_id] = total_val; + total = op.simd_reduce(total); + if (simd_per_group > 1) { + if (simd_lane_id == 0) { + shared_vals[simd_group_id] = total; + } + + // Reduction within thread group + threadgroup_barrier(mem_flags::mem_threadgroup); + total = lid.x < simd_per_group ? shared_vals[lid.x] : op.init; + total = op.simd_reduce(total); } - // Reduction within thread group - threadgroup_barrier(mem_flags::mem_threadgroup); - total_val = lid < simd_per_group ? local_vals[lid] : op.init; - total_val = op.simd_reduce(total_val); - - // Reduction across threadgroups - if (lid == 0) { - op.atomic_update(out, total_val); - } -} - -template -[[kernel]] void all_reduce_no_atomics( - const device T* in [[buffer(0)]], - device U* out [[buffer(1)]], - const device size_t& in_size [[buffer(2)]], - uint gid [[thread_position_in_grid]], - uint lid [[thread_position_in_threadgroup]], - uint grid_size [[threads_per_grid]], - uint simd_per_group [[simdgroups_per_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint thread_group_id [[threadgroup_position_in_grid]]) { - Op op; - threadgroup U local_vals[simd_size]; - - U total_val = - per_thread_all_reduce(in, in_size, gid, grid_size); - - // Reduction within simd group (simd_add isn't supported for uint64/int64 - // types) - for (uint16_t lane_offset = simd_size / 2; lane_offset > 0; - lane_offset /= 2) { - total_val = op(total_val, simd_shuffle_down(total_val, lane_offset)); - } - // Write simd group reduction results to local memory - if (simd_lane_id == 0) { - local_vals[simd_group_id] = total_val; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Reduction of simdgroup reduction results within threadgroup. - total_val = lid < simd_per_group ? local_vals[lid] : op.init; - for (uint16_t lane_offset = simd_size / 2; lane_offset > 0; - lane_offset /= 2) { - total_val = op(total_val, simd_shuffle_down(total_val, lane_offset)); - } - - // Reduction across threadgroups - if (lid == 0) { - out[thread_group_id] = total_val; + if (lid.x == 0) { + out[gid.y] = total; } } diff --git a/mlx/backend/metal/kernels/reduction/reduce_col.h b/mlx/backend/metal/kernels/reduction/reduce_col.h index 568d0cebc..2d102911a 100644 --- a/mlx/backend/metal/kernels/reduction/reduce_col.h +++ b/mlx/backend/metal/kernels/reduction/reduce_col.h @@ -1,165 +1,298 @@ // Copyright © 2023-2024 Apple Inc. 
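// ---------------------------------------------------------------------------
// Note: the new kernels in this file rely on a `looped_elem_to_loc<NDIMS>`
// helper whose definition is not part of this diff. The sketch below is only
// an illustration of the idea (an incrementally advanced multi-dimensional
// offset over the reduction shape); the name `looped_elem_to_loc_sketch` and
// all details are assumptions, not the actual MLX implementation.

template <int NDIMS>
struct looped_elem_to_loc_sketch {
  int index[NDIMS] = {0}; // current position in the reduction shape
  size_t offset = 0;      // cached linear offset into the input

  // Advance by one element, innermost axis fastest.
  void next(const int* shape, const size_t* strides) {
    for (int i = NDIMS - 1; i >= 0; i--) {
      offset += strides[i];
      if (++index[i] < shape[i]) {
        return;
      }
      // Carry: rewind this axis and let the loop bump the next outer one.
      offset -= size_t(shape[i]) * strides[i];
      index[i] = 0;
    }
  }

  // Advance by n elements.
  void next(int n, const int* shape, const size_t* strides) {
    for (int j = 0; j < n; j++) {
      next(shape, strides);
    }
  }

  // The kernels also pass the flat element index r; in this cached form it is
  // unused and the precomputed offset is returned directly.
  size_t location(size_t /*r*/, const int*, const size_t*, int) {
    return offset;
  }
};
// The real helper presumably also covers NDIMS == 0 (dynamic rank) by falling
// back to elem_to_loc; that case is omitted from this sketch.
// ---------------------------------------------------------------------------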
-/////////////////////////////////////////////////////////////////////////////// -// Small column reduce kernel -/////////////////////////////////////////////////////////////////////////////// - -template +template < + typename T, + typename U, + typename Op, + int NDIMS = 0, + int N_READS = REDUCE_N_READS> [[kernel]] void col_reduce_small( const device T* in [[buffer(0)]], device U* out [[buffer(1)]], const constant size_t& reduction_size [[buffer(2)]], const constant size_t& reduction_stride [[buffer(3)]], - const constant size_t& out_size [[buffer(4)]], - const constant int* shape [[buffer(5)]], - const constant size_t* strides [[buffer(6)]], - const constant int& ndim [[buffer(7)]], - const constant size_t& non_col_reductions [[buffer(8)]], - const constant int* non_col_shapes [[buffer(9)]], - const constant size_t* non_col_strides [[buffer(10)]], - const constant int& non_col_ndim [[buffer(11)]], - uint tid [[thread_position_in_grid]]) { - // Appease the compiler - (void)out_size; - + const constant int* shape [[buffer(4)]], + const constant size_t* strides [[buffer(5)]], + const constant int& ndim [[buffer(6)]], + const constant int* reduce_shape [[buffer(7)]], + const constant size_t* reduce_strides [[buffer(8)]], + const constant int& reduce_ndim [[buffer(9)]], + const constant size_t& non_col_reductions [[buffer(10)]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 gsize [[threadgroups_per_grid]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[thread_position_in_grid]], + uint3 tsize [[threads_per_grid]]) { Op op; - U total_val = Op::init; + looped_elem_to_loc loop; + const device T* row; - auto out_idx = tid; + // Case 1: + // reduction_stride is small, reduction_size is small and non_col_reductions + // is small. Each thread computes reduction_stride outputs. 
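// (Worked example, for illustration only; the numbers are hypothetical.)
// With reduction_stride = 6 and N_READS = REDUCE_N_READS = 4, the code below
// computes blocks = 6 / 4 = 1 and extra = 6 - 4 = 2, so every thread keeps
// 6 running totals: indices 0..3 come from the full block and 4..5 from the
// bounds-checked tail. After looping over reduction_size rows for each of the
// non_col_reductions slices, the thread writes its 6 contiguous outputs.
// The totals[31] buffer suggests the host only selects this path when
// reduction_stride fits in 31 (an assumption about the dispatch logic).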
+ if (reduction_size * non_col_reductions < 64) { + U totals[31]; + for (int i = 0; i < 31; i++) { + totals[i] = Op::init; + } - in += elem_to_loc( - out_idx, - shape + non_col_ndim, - strides + non_col_ndim, - ndim - non_col_ndim); + short stride = reduction_stride; + short size = reduction_size; + short blocks = stride / N_READS; + short extra = stride - blocks * N_READS; - for (uint i = 0; i < non_col_reductions; i++) { - size_t in_idx = - elem_to_loc(i, non_col_shapes, non_col_strides, non_col_ndim); + size_t out_idx = tid.x + tsize.y * size_t(tid.y); + in += elem_to_loc(out_idx, shape, strides, ndim); - for (uint j = 0; j < reduction_size; j++, in_idx += reduction_stride) { - U val = static_cast(in[in_idx]); - total_val = op(total_val, val); + for (uint r = 0; r < non_col_reductions; r++) { + row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim); + + for (short i = 0; i < size; i++) { + for (short j = 0; j < blocks; j++) { + for (short k = 0; k < N_READS; k++) { + totals[j * N_READS + k] = + op(totals[j * N_READS + k], + static_cast(row[i * stride + j * N_READS + k])); + } + } + for (short k = 0; k < extra; k++) { + totals[blocks * N_READS + k] = + op(totals[blocks * N_READS + k], + static_cast(row[i * stride + blocks * N_READS + k])); + } + } + + loop.next(reduce_shape, reduce_strides); + } + out += out_idx * reduction_stride; + for (short j = 0; j < stride; j++) { + out[j] = totals[j]; } } - out[out_idx] = total_val; -} - -/////////////////////////////////////////////////////////////////////////////// -// Column reduce helper -/////////////////////////////////////////////////////////////////////////////// - -template -METAL_FUNC U _contiguous_strided_reduce( - const device T* in, - threadgroup U* local_data, - uint in_idx, - uint reduction_size, - uint reduction_stride, - uint2 tid, - uint2 lid, - uint2 lsize) { - Op op; - U total_val = Op::init; - - uint base_offset = (tid.y * lsize.y + lid.y) * N_READS; - for (uint r = 0; r < N_READS && (base_offset + r) < reduction_size; r++) { - uint offset = base_offset + r; - total_val = - op(static_cast(total_val), in[in_idx + offset * reduction_stride]); - } - local_data[lsize.y * lid.x + lid.y] = total_val; - threadgroup_barrier(mem_flags::mem_threadgroup); - - U val = Op::init; - if (lid.y == 0) { - // Perform reduction across columns in thread group - for (uint i = 0; i < lsize.y; i++) { - val = op(val, local_data[lsize.y * lid.x + i]); + // Case 2: + // Reduction stride is small but everything else can be big. We loop both + // across reduction size and non_col_reductions. Each simdgroup produces + // N_READS outputs. 
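// (Worked example, for illustration only; the numbers are hypothetical.)
// With reduction_stride = 7 and N_READS = 4 the code below builds a tile of
// shape ((7 + 3) / 4, 32) = (2, 32). A thread with lid = 5, i.e.
// simd_group_id = 0 and simd_lane_id = 5, gets
// offset = ((5 % 2) * 4, 5 / 2) = (4, 2), so it reads columns 4..7 of every
// 32nd row starting at row 2; safe = (4 + 4 <= 7) is false, so the read of
// column 7 is masked with op.init. Partial results are staged in shared
// memory with sm_stride = 2 * 4 = 8 elements per row before the final simd
// reduction.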
+ else { + threadgroup U shared_vals[1024]; + U totals[N_READS]; + for (int i = 0; i < N_READS; i++) { + totals[i] = Op::init; } - } - return val; -} + short stride = reduction_stride; + short lid = simd_group_id * simd_size + simd_lane_id; + short2 tile((stride + N_READS - 1) / N_READS, 32); + short2 offset((lid % tile.x) * N_READS, lid / tile.x); + short sm_stride = tile.x * N_READS; + bool safe = offset.x + N_READS <= stride; -/////////////////////////////////////////////////////////////////////////////// -// Column reduce kernel -/////////////////////////////////////////////////////////////////////////////// + size_t out_idx = gid.y + gsize.y * size_t(gid.z); + in += elem_to_loc(out_idx, shape, strides, ndim) + offset.x; -template -[[kernel]] void col_reduce_general( - const device T* in [[buffer(0)]], - device mlx_atomic* out [[buffer(1)]], - const constant size_t& reduction_size [[buffer(2)]], - const constant size_t& reduction_stride [[buffer(3)]], - const constant size_t& out_size [[buffer(4)]], - const constant int* shape [[buffer(5)]], - const constant size_t* strides [[buffer(6)]], - const constant int& ndim [[buffer(7)]], - threadgroup U* local_data [[threadgroup(0)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint3 lsize [[threads_per_threadgroup]]) { - auto out_idx = tid.x * lsize.x + lid.x; - auto in_idx = elem_to_loc(out_idx + tid.z * out_size, shape, strides, ndim); + // Read cooperatively and contiguously and aggregate the partial results. + size_t total = non_col_reductions * reduction_size; + loop.next(offset.y, reduce_shape, reduce_strides); + for (size_t r = offset.y; r < total; r += simd_size) { + row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim); - Op op; - if (out_idx < out_size) { - U val = _contiguous_strided_reduce( - in, - local_data, - in_idx, - reduction_size, - reduction_stride, - tid.xy, - lid.xy, - lsize.xy); + if (safe) { + for (int i = 0; i < N_READS; i++) { + totals[i] = op(static_cast(row[i]), totals[i]); + } + } else { + U vals[N_READS]; + for (int i = 0; i < N_READS; i++) { + vals[i] = (offset.x + i < stride) ? static_cast(row[i]) : op.init; + } + for (int i = 0; i < N_READS; i++) { + totals[i] = op(vals[i], totals[i]); + } + } - // Write out reduction results generated by threadgroups working on specific - // output element, contiguously. - if (lid.y == 0) { - op.atomic_update(out, val, out_idx); + loop.next(simd_size, reduce_shape, reduce_strides); + } + + // Each thread holds N_READS partial results but the simdgroups are not + // aligned to do the reduction across the simdgroup so we write our results + // in the shared memory and read them back according to the simdgroup. + for (int i = 0; i < N_READS; i++) { + shared_vals[offset.y * sm_stride + offset.x + i] = totals[i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + for (int i = 0; i < N_READS; i++) { + totals[i] = op.simd_reduce( + shared_vals[simd_lane_id * sm_stride + simd_group_id * N_READS + i]); + } + + // Write the output. + if (simd_lane_id == 0) { + short column = simd_group_id * N_READS; + out += out_idx * reduction_stride + column; + if (column + N_READS <= stride) { + for (int i = 0; i < N_READS; i++) { + out[i] = totals[i]; + } + } else { + for (int i = 0; column + i < stride; i++) { + out[i] = totals[i]; + } + } } } } -template -[[kernel]] void col_reduce_general_no_atomics( +/** + * Our approach is the following simple looped approach: + * 1. 
Each thread keeps running totals for BN / n_simdgroups outputs. + * 2. Load a tile BM, BN in shared memory. + * 3. Add the values from shared memory to the current running totals. + * Neighboring threads access different rows (transposed acces). + * 4. Move ahead to the next tile until the M axis is exhausted. + * 5. Move ahead to the next non column reduction + * 6. Simd reduce the running totals + * 7. Write them to the output + * + * The kernel becomes verbose because we support all kinds of OOB checks. For + * instance if we choose that reduction_stride must be larger than BN then we + * can get rid of half the kernel. + */ +template < + typename T, + typename U, + typename Op, + int NDIMS = 0, + int BM = 8, + int BN = 128> +[[kernel]] void col_reduce_looped( const device T* in [[buffer(0)]], device U* out [[buffer(1)]], const constant size_t& reduction_size [[buffer(2)]], const constant size_t& reduction_stride [[buffer(3)]], - const constant size_t& out_size [[buffer(4)]], - const constant int* shape [[buffer(5)]], - const constant size_t* strides [[buffer(6)]], - const constant int& ndim [[buffer(7)]], - threadgroup U* local_data [[threadgroup(0)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint3 gid [[thread_position_in_grid]], - uint3 lsize [[threads_per_threadgroup]], - uint3 gsize [[threads_per_grid]]) { - auto out_idx = tid.x * lsize.x + lid.x; - auto in_idx = elem_to_loc(out_idx + tid.z * out_size, shape, strides, ndim); + const constant int* shape [[buffer(4)]], + const constant size_t* strides [[buffer(5)]], + const constant int& ndim [[buffer(6)]], + const constant int* reduce_shape [[buffer(7)]], + const constant size_t* reduce_strides [[buffer(8)]], + const constant int& reduce_ndim [[buffer(9)]], + const constant size_t& non_col_reductions [[buffer(10)]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 gsize [[threadgroups_per_grid]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + Op op; + constexpr int n_simdgroups = 4; + constexpr short tgp_size = n_simdgroups * simd_size; + constexpr short n_reads = (BM * BN) / tgp_size; + constexpr short n_read_blocks = BN / n_reads; - if (out_idx < out_size) { - U val = _contiguous_strided_reduce( - in, - local_data, - in_idx, - reduction_size, - reduction_stride, - tid.xy, - lid.xy, - lsize.xy); + threadgroup U shared_vals[BN * BM]; + U totals[n_reads]; + looped_elem_to_loc loop; + const device T* row; - // Write out reduction results generated by threadgroups working on specific - // output element, contiguously. 
- if (lid.y == 0) { - uint tgsize_y = ceildiv(gsize.y, lsize.y); - uint tgsize_z = ceildiv(gsize.z, lsize.z); - out[tgsize_y * tgsize_z * gid.x + tgsize_y * tid.z + tid.y] = val; + for (int i = 0; i < n_reads; i++) { + totals[i] = Op::init; + } + + short lid = simd_group_id * simd_size + simd_lane_id; + short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks); + size_t column = BN * gid.x + offset.x; + bool safe = column + n_reads <= reduction_stride; + + size_t out_idx = gid.y + gsize.y * size_t(gid.z); + size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim); + in += in_idx + column; + + size_t total = non_col_reductions * reduction_size; + loop.next(offset.y, reduce_shape, reduce_strides); + for (size_t r = offset.y; r < total; r += BM) { + row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim); + + if (safe) { + for (int i = 0; i < n_reads; i++) { + totals[i] = op(static_cast(row[i]), totals[i]); + } + } else { + U vals[n_reads]; + for (int i = 0; i < n_reads; i++) { + vals[i] = + (column + i < reduction_stride) ? static_cast(row[i]) : op.init; + } + for (int i = 0; i < n_reads; i++) { + totals[i] = op(vals[i], totals[i]); + } + } + + loop.next(BM, reduce_shape, reduce_strides); + } + + // We can use a simd reduction to accumulate across BM so each thread writes + // the partial output to SM and then each simdgroup does BN / n_simdgroups + // accumulations. + if (BM == 32) { + constexpr int n_outputs = BN / n_simdgroups; + static_assert( + BM != 32 || n_outputs == n_reads, + "The tile should be selected such that n_outputs == n_reads"); + for (int i = 0; i < n_reads; i++) { + shared_vals[offset.y * BN + offset.x + i] = totals[i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + short2 out_offset(simd_group_id * n_outputs, simd_lane_id); + for (int i = 0; i < n_outputs; i++) { + totals[i] = + op.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]); + } + + // Write the output. + if (simd_lane_id == 0) { + size_t out_column = BN * gid.x + out_offset.x; + out += out_idx * reduction_stride + out_column; + if (out_column + n_outputs <= reduction_stride) { + for (int i = 0; i < n_outputs; i++) { + out[i] = totals[i]; + } + } else { + for (int i = 0; out_column + i < reduction_stride; i++) { + out[i] = totals[i]; + } + } + } + } + + // Each thread holds n_reads partial results. We write them all out to shared + // memory and threads with offset.y == 0 aggregate the columns and write the + // outputs. + else { + short x_block = offset.x / n_reads; + for (int i = 0; i < n_reads; i++) { + shared_vals[x_block * BM * n_reads + i * BM + offset.y] = totals[i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (offset.y == 0) { + for (int i = 0; i < n_reads; i++) { + for (int j = 1; j < BM; j++) { + totals[i] = + op(shared_vals[x_block * BM * n_reads + i * BM + j], totals[i]); + } + } + } + + // Write the output. + if (offset.y == 0) { + out += out_idx * reduction_stride + column; + if (safe) { + for (int i = 0; i < n_reads; i++) { + out[i] = totals[i]; + } + } else { + for (int i = 0; column + i < reduction_stride; i++) { + out[i] = totals[i]; + } + } } } } diff --git a/mlx/backend/metal/kernels/reduction/reduce_row.h b/mlx/backend/metal/kernels/reduction/reduce_row.h index cefaf63e9..903265295 100644 --- a/mlx/backend/metal/kernels/reduction/reduce_row.h +++ b/mlx/backend/metal/kernels/reduction/reduce_row.h @@ -1,287 +1,366 @@ // Copyright © 2023-2024 Apple Inc. 
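// (Worked example, for illustration; the numbers are hypothetical.) The
// helpers below consume a cooperative blocks/extra split of each row. For a
// row of row_size = 1000 with lsize.x = 128 threads and N_READS = 4, one full
// pass covers 128 * 4 = 512 elements, so blocks = 1 and extra = 488 (the
// split mirrors the one computed in all_reduce above). In the tail, a thread
// with lid.x = 100 has index = 400 and 400 + 4 <= 488, so it reads a full
// 4-vector, while a thread with lid.x = 122 has index = 488 and reads
// nothing.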
-/////////////////////////////////////////////////////////////////////////////// -// Small row reductions -/////////////////////////////////////////////////////////////////////////////// +// Row reduction utilities +// - `per_thread_row_reduce` collaborative partial reduction in the threadgroup +// - `threadgroup_reduce` collaborative reduction in the threadgroup such that +// lid.x == 0 holds the reduced value +// - `thread_reduce` simple loop and reduce the row -// Each thread reduces for one output -template -[[kernel]] void row_reduce_general_small( - const device T* in [[buffer(0)]], - device U* out [[buffer(1)]], - const constant size_t& reduction_size [[buffer(2)]], - const constant size_t& out_size [[buffer(3)]], - const constant size_t& non_row_reductions [[buffer(4)]], - const constant int* shape [[buffer(5)]], - const constant size_t* strides [[buffer(6)]], - const constant int& ndim [[buffer(7)]], - uint lid [[thread_position_in_grid]]) { +/** + * The thread group collaboratively reduces across the rows with bounds + * checking. In the end each thread holds a part of the reduction. + */ +template < + typename T, + typename U, + typename Op, + int N_READS = REDUCE_N_READS, + int N_WRITES = REDUCE_N_WRITES> +METAL_FUNC void per_thread_row_reduce( + thread U totals[N_WRITES], + const device T* inputs[N_WRITES], + int blocks, + int extra, + uint lsize_x, + uint lid_x) { Op op; - uint out_idx = lid; - - if (out_idx >= out_size) { - return; + // Set up the accumulator registers + for (int i = 0; i < N_WRITES; i++) { + totals[i] = Op::init; } - U total_val = Op::init; + // Loop over the reduction size within thread group + for (int i = 0; i < blocks; i++) { + for (int j = 0; j < N_WRITES; j++) { + for (int i = 0; i < N_READS; i++) { + totals[j] = op(static_cast(inputs[j][i]), totals[j]); + } - for (short r = 0; r < short(non_row_reductions); r++) { - uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim); - const device T* in_row = in + in_idx; - - for (short i = 0; i < short(reduction_size); i++) { - total_val = op(static_cast(in_row[i]), total_val); + inputs[j] += lsize_x * N_READS; } } - out[out_idx] = total_val; -} - -// Each simdgroup reduces for one output -template -[[kernel]] void row_reduce_general_med( - const device T* in [[buffer(0)]], - device U* out [[buffer(1)]], - const constant size_t& reduction_size [[buffer(2)]], - const constant size_t& out_size [[buffer(3)]], - const constant size_t& non_row_reductions [[buffer(4)]], - const constant int* shape [[buffer(5)]], - const constant size_t* strides [[buffer(6)]], - const constant int& ndim [[buffer(7)]], - uint tid [[threadgroup_position_in_grid]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_per_group [[dispatch_simdgroups_per_threadgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - Op op; - - uint out_idx = simd_per_group * tid + simd_group_id; - - if (out_idx >= out_size) { - return; - } - - U total_val = Op::init; - - if (short(non_row_reductions) == 1) { - uint in_idx = elem_to_loc(out_idx, shape, strides, ndim); - const device T* in_row = in + in_idx; - - for (short i = simd_lane_id; i < short(reduction_size); i += 32) { - total_val = op(static_cast(in_row[i]), total_val); - } - } - - else if (short(non_row_reductions) >= 32) { - for (short r = simd_lane_id; r < short(non_row_reductions); r += 32) { - uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim); - const device T* in_row = in + in_idx; - - for (short i = 0; i < short(reduction_size); 
i++) { - total_val = op(static_cast(in_row[i]), total_val); + // Separate case for the last set as we close the reduction size + int index = lid_x * N_READS; + if (index + N_READS <= extra) { + for (int j = 0; j < N_WRITES; j++) { + for (int i = 0; i < N_READS; i++) { + totals[j] = op(static_cast(inputs[j][i]), totals[j]); } } - - } - - else { - const short n_reductions = - short(reduction_size) * short(non_row_reductions); - const short reductions_per_thread = - (n_reductions + simd_size - 1) / simd_size; - - const short r_st = simd_lane_id / reductions_per_thread; - const short r_ed = short(non_row_reductions); - const short r_jump = simd_size / reductions_per_thread; - - const short i_st = simd_lane_id % reductions_per_thread; - const short i_ed = short(reduction_size); - const short i_jump = reductions_per_thread; - - if (r_st < r_jump) { - for (short r = r_st; r < r_ed; r += r_jump) { - uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim); - const device T* in_row = in + in_idx; - - for (short i = i_st; i < i_ed; i += i_jump) { - total_val = op(static_cast(in_row[i]), total_val); - } + } else { + for (int j = 0; j < N_WRITES; j++) { + for (int i = 0; index + i < extra; i++) { + totals[j] = op(static_cast(inputs[j][i]), totals[j]); } } } - - total_val = op.simd_reduce(total_val); - - if (simd_lane_id == 0) { - out[out_idx] = total_val; - } } -/////////////////////////////////////////////////////////////////////////////// -// Large row reductions -/////////////////////////////////////////////////////////////////////////////// - -template -METAL_FUNC U per_thread_row_reduce( +/** + * Consecutive rows in a contiguous array. + */ +template < + typename T, + typename U, + typename Op, + int N_READS = REDUCE_N_READS, + int N_WRITES = REDUCE_N_WRITES> +METAL_FUNC void per_thread_row_reduce( + thread U totals[N_WRITES], const device T* in, const constant size_t& reduction_size, - const constant size_t& out_size, + int blocks, + int extra, + uint lsize_x, + uint lid_x) { + // Set up the input pointers + const device T* inputs[N_WRITES]; + inputs[0] = in + lid_x * N_READS; + for (int i = 1; i < N_READS; i++) { + inputs[i] = inputs[i - 1] + reduction_size; + } + + per_thread_row_reduce( + totals, inputs, blocks, extra, lsize_x, lid_x); +} + +/** + * Consecutive rows in an arbitrarily ordered array. 
+ */ +template < + typename T, + typename U, + typename Op, + int N_READS = REDUCE_N_READS, + int N_WRITES = REDUCE_N_WRITES> +METAL_FUNC void per_thread_row_reduce( + thread U totals[N_WRITES], + const device T* in, + const size_t row_idx, + int blocks, + int extra, const constant int* shape, const constant size_t* strides, const constant int& ndim, uint lsize_x, - uint lid_x, - uint2 tid) { - Op op; - - // Each threadgroup handles 1 reduction - // TODO: Specializing elem_to_loc would be slightly faster - int idx = tid.y * out_size + tid.x; - int extra_offset = elem_to_loc(idx, shape, strides, ndim); - in += extra_offset + lid_x * N_READS; - - // The reduction is accumulated here - U total_val = Op::init; - - // Loop over the reduction size within thread group - int r = 0; - for (; r < (int)ceildiv(reduction_size, N_READS * lsize_x) - 1; r++) { - T vals[N_READS]; - for (int i = 0; i < N_READS; i++) { - vals[i] = in[i]; - } - for (int i = 0; i < N_READS; i++) { - total_val = op(static_cast(vals[i]), total_val); - } - - in += lsize_x * N_READS; + uint lid_x) { + // Set up the input pointers + const device T* inputs[N_WRITES]; + in += lid_x * N_READS; + for (int i = 0; i < N_READS; i++) { + inputs[i] = in + elem_to_loc(row_idx + i, shape, strides, ndim); } - // Separate case for the last set as we close the reduction size - size_t reduction_index = (lid_x + (size_t)lsize_x * r) * N_READS; - if (reduction_index < reduction_size) { - int max_reads = reduction_size - reduction_index; - - T vals[N_READS]; - for (int i = 0; i < N_READS; i++) { - int idx = min(i, max_reads - 1); - vals[i] = static_cast(in[idx]); - } - for (int i = 0; i < N_READS; i++) { - T val = i < max_reads ? vals[i] : Op::init; - total_val = op(static_cast(val), total_val); - } - } - - return total_val; + per_thread_row_reduce( + totals, inputs, blocks, extra, lsize_x, lid_x); } -template -[[kernel]] void row_reduce_general( - const device T* in [[buffer(0)]], - device mlx_atomic* out [[buffer(1)]], - const constant size_t& reduction_size [[buffer(2)]], - const constant size_t& out_size [[buffer(3)]], - const constant size_t& non_row_reductions [[buffer(4)]], - const constant int* shape [[buffer(5)]], - const constant size_t* strides [[buffer(6)]], - const constant int& ndim [[buffer(7)]], +/** + * Reduce within the threadgroup. + */ +template < + typename T, + typename U, + typename Op, + int N_READS = REDUCE_N_READS, + int N_WRITES = REDUCE_N_WRITES> +METAL_FUNC void threadgroup_reduce( + thread U totals[N_WRITES], + threadgroup U* shared_vals, uint3 lid [[thread_position_in_threadgroup]], - uint3 lsize [[threads_per_threadgroup]], - uint3 tid [[threadgroup_position_in_grid]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_per_group [[simdgroups_per_threadgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - (void)non_row_reductions; - Op op; - threadgroup U local_vals[simd_size]; - U total_val = per_thread_row_reduce( - in, - reduction_size, - out_size, - shape, - strides, - ndim, - lsize.x, - lid.x, - tid.xy); - - total_val = op.simd_reduce(total_val); - - // Prepare next level - if (simd_lane_id == 0) { - local_vals[simd_group_id] = total_val; + // Simdgroup first + for (int i = 0; i < N_WRITES; i++) { + totals[i] = op.simd_reduce(totals[i]); } - threadgroup_barrier(mem_flags::mem_threadgroup); - // Reduction within thread group - // Only needed if multiple simd groups - if (reduction_size > simd_size) { - total_val = lid.x < simd_per_group ? 
local_vals[lid.x] : op.init; - total_val = op.simd_reduce(total_val); - } - // Update output - if (lid.x == 0) { - op.atomic_update(out, total_val, tid.x); + // Across simdgroups + if (simd_per_group > 1) { + if (simd_lane_id == 0) { + for (int i = 0; i < N_WRITES; i++) { + shared_vals[simd_group_id * N_WRITES + i] = totals[i]; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + U values[N_WRITES]; + for (int i = 0; i < N_WRITES; i++) { + values[i] = (lid.x < simd_per_group) ? shared_vals[lid.x * N_WRITES + i] + : op.init; + } + + for (int i = 0; i < N_WRITES; i++) { + totals[i] = op.simd_reduce(values[i]); + } } } template -[[kernel]] void row_reduce_general_no_atomics( +METAL_FUNC void +thread_reduce(thread U& total, const device T* row, int blocks, int extra) { + Op op; + for (int i = 0; i < blocks; i++) { + U vals[N_READS]; + for (int j = 0; j < N_READS; j++) { + vals[j] = row[j]; + } + for (int j = 0; j < N_READS; j++) { + total = op(vals[j], total); + } + row += N_READS; + } + for (int i = 0; i < extra; i++) { + total = op(*row++, total); + } +} + +// Reduction kernels +// - `row_reduce_small` depending on the non-row reductions and row size it +// either just loops over everything or a simd collaboratively reduces the +// non_row reductions. In the first case one thread is responsible for one +// output on the 2nd one simd is responsible for one output. +// - `row_reduce_simple` simple contiguous row reduction +// - `row_reduce_looped` simply loop and reduce each row for each non-row +// reduction. One threadgroup is responsible for one output. + +template < + typename T, + typename U, + typename Op, + int NDIMS = 0, + int N_READS = REDUCE_N_READS> +[[kernel]] void row_reduce_small( + const device T* in [[buffer(0)]], + device U* out [[buffer(1)]], + const constant size_t& row_size [[buffer(2)]], + const constant size_t& non_row_reductions [[buffer(3)]], + const constant int* shape [[buffer(4)]], + const constant size_t* strides [[buffer(5)]], + const constant int& ndim [[buffer(6)]], + const constant int* reduce_shape [[buffer(7)]], + const constant size_t* reduce_strides [[buffer(8)]], + const constant int& reduce_ndim [[buffer(9)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 gsize [[threadgroups_per_grid]], + uint3 tid [[thread_position_in_grid]], + uint3 tsize [[threads_per_grid]]) { + Op op; + + U total_val = Op::init; + looped_elem_to_loc loop; + + // Precompute some row reduction numbers + const device T* row; + int blocks = row_size / N_READS; + int extra = row_size % N_READS; + + if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) { + // Simple loop over non_row_reductions and reduce the row in the thread. + size_t out_idx = tid.x + tsize.y * size_t(tid.y); + in += elem_to_loc(out_idx, shape, strides, ndim); + + for (uint r = 0; r < non_row_reductions; r++) { + row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim); + thread_reduce(total_val, row, blocks, extra); + loop.next(reduce_shape, reduce_strides); + } + + out[out_idx] = total_val; + } else { + // Collaboratively reduce over non_row_reductions in the simdgroup. Each + // thread reduces every 32nd row and then a simple simd reduce. 
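The kernels above all decompose a row of `row_size` elements into `blocks` full passes of `N_READS` values plus an `extra` tail. A minimal scalar sketch of that split, in plain C++ rather than Metal (sum stands in for `Op`; nothing here is part of the patch):

#include <cstdio>

// Illustrative only: a scalar version of the blocks/extra split that
// thread_reduce performs on a single row (sum used as the example op).
float reduce_row_sum(const float* row, int row_size, int n_reads /* N_READS */) {
  int blocks = row_size / n_reads; // full passes of n_reads elements
  int extra = row_size % n_reads;  // leftover tail elements
  float total = 0.0f;
  for (int b = 0; b < blocks; b++) {
    for (int j = 0; j < n_reads; j++) {
      total += row[j];
    }
    row += n_reads;
  }
  for (int i = 0; i < extra; i++) {
    total += *row++;
  }
  return total;
}

int main() {
  float row[10] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
  // With row_size = 10 and N_READS = 4: blocks = 2, extra = 2.
  std::printf("%f\n", reduce_row_sum(row, 10, 4));
}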
+ size_t out_idx = gid.y + gsize.y * size_t(gid.z); + in += elem_to_loc(out_idx, shape, strides, ndim); + + loop.next(simd_lane_id, reduce_shape, reduce_strides); + + for (uint r = simd_lane_id; r < non_row_reductions; r += simd_size) { + row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim); + thread_reduce(total_val, row, blocks, extra); + loop.next(simd_size, reduce_shape, reduce_strides); + } + + total_val = op.simd_reduce(total_val); + + if (simd_lane_id == 0) { + out[out_idx] = total_val; + } + } +} + +template < + typename T, + typename U, + typename Op, + int N_READS = REDUCE_N_READS, + int N_WRITES = REDUCE_N_WRITES> +[[kernel]] void row_reduce_simple( const device T* in [[buffer(0)]], device U* out [[buffer(1)]], const constant size_t& reduction_size [[buffer(2)]], const constant size_t& out_size [[buffer(3)]], - const constant size_t& non_row_reductions [[buffer(4)]], - const constant int* shape [[buffer(5)]], - const constant size_t* strides [[buffer(6)]], - const constant int& ndim [[buffer(7)]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 gsize [[threadgroups_per_grid]], uint3 lid [[thread_position_in_threadgroup]], uint3 lsize [[threads_per_threadgroup]], - uint3 gsize [[threads_per_grid]], - uint3 tid [[threadgroup_position_in_grid]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_per_group [[simdgroups_per_threadgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - (void)non_row_reductions; + threadgroup U shared_vals[simd_size * N_WRITES]; + U totals[N_WRITES]; - Op op; - - threadgroup U local_vals[simd_size]; - U total_val = per_thread_row_reduce( - in, - reduction_size, - out_size, - shape, - strides, - ndim, - lsize.x, - lid.x, - tid.xy); - - // Reduction within simd group - simd_add isn't supported for int64 types - for (uint16_t i = simd_size / 2; i > 0; i /= 2) { - total_val = op(total_val, simd_shuffle_down(total_val, i)); + // Move to the row + size_t out_idx = N_WRITES * (gid.y + gsize.y * size_t(gid.z)); + if (out_idx + N_WRITES > out_size) { + out_idx = out_size - N_WRITES; } + in += out_idx * reduction_size; + out += out_idx; - // Prepare next level - if (simd_lane_id == 0) { - local_vals[simd_group_id] = total_val; - } - threadgroup_barrier(mem_flags::mem_threadgroup); + // Each thread reduces across the row + int blocks = reduction_size / (lsize.x * N_READS); + int extra = reduction_size - blocks * (lsize.x * N_READS); + per_thread_row_reduce( + totals, in, reduction_size, blocks, extra, lsize.x, lid.x); - // Reduction within thread group - // Only needed if thread group has multiple simd groups - if (ceildiv(reduction_size, N_READS) > simd_size) { - total_val = lid.x < simd_per_group ? 
local_vals[lid.x] : op.init; - for (uint16_t i = simd_size / 2; i > 0; i /= 2) { - total_val = op(total_val, simd_shuffle_down(total_val, i)); + // Reduce across the threadgroup + threadgroup_reduce( + totals, shared_vals, lid, simd_lane_id, simd_per_group, simd_group_id); + + // Write the output + if (lid.x == 0) { + for (int i = 0; i < N_WRITES; i++) { + out[i] = totals[i]; } } - // Write row reduce output for threadgroup with 1st thread in thread group +} + +template < + typename T, + typename U, + typename Op, + int NDIMS = 0, + int N_READS = REDUCE_N_READS> +[[kernel]] void row_reduce_looped( + const device T* in [[buffer(0)]], + device U* out [[buffer(1)]], + const constant size_t& row_size [[buffer(2)]], + const constant size_t& non_row_reductions [[buffer(3)]], + const constant int* shape [[buffer(4)]], + const constant size_t* strides [[buffer(5)]], + const constant int& ndim [[buffer(6)]], + const constant int* reduce_shape [[buffer(7)]], + const constant size_t* reduce_strides [[buffer(8)]], + const constant int& reduce_ndim [[buffer(9)]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 gsize [[threadgroups_per_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 lsize [[threads_per_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_per_group [[simdgroups_per_threadgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + Op op; + threadgroup U shared_vals[simd_size]; + U total = Op::init; + + size_t out_idx = gid.y + gsize.y * size_t(gid.z); + + // lid.x * N_READS breaks the per_thread_row_reduce interface a bit. Maybe it + // needs a small refactor. + in += elem_to_loc(out_idx, shape, strides, ndim) + lid.x * N_READS; + + looped_elem_to_loc loop; + const device T* row; + int blocks = row_size / (lsize.x * N_READS); + int extra = row_size - blocks * (lsize.x * N_READS); + + for (size_t i = 0; i < non_row_reductions; i++) { + row = in + loop.location(i, reduce_shape, reduce_strides, reduce_ndim); + + // Each thread reduces across the row + U row_total; + per_thread_row_reduce( + &row_total, &row, blocks, extra, lsize.x, lid.x); + + // Aggregate across rows + total = op(total, row_total); + + loop.next(reduce_shape, reduce_strides); + } + + // Reduce across the threadgroup + threadgroup_reduce( + &total, shared_vals, lid, simd_lane_id, simd_per_group, simd_group_id); + + // Write the output if (lid.x == 0) { - out[(ceildiv(gsize.y, lsize.y) * tid.x) + tid.y] = total_val; + out[out_idx] = total; } } diff --git a/mlx/backend/metal/kernels/scatter.h b/mlx/backend/metal/kernels/scatter.h index 108e40adc..b4c6f0061 100644 --- a/mlx/backend/metal/kernels/scatter.h +++ b/mlx/backend/metal/kernels/scatter.h @@ -18,7 +18,7 @@ METAL_FUNC void scatter_1d_index_impl( uint2 gid [[thread_position_in_grid]]) { Op op; - uint out_idx = 0; + size_t out_idx = 0; for (int i = 0; i < NIDX; i++) { auto idx_val = offset_neg_idx(idx_buffers[i][gid.y], out_shape[i]); out_idx += idx_val * out_strides[i]; diff --git a/mlx/backend/metal/kernels/utils.h b/mlx/backend/metal/kernels/utils.h index 4b3393fae..17d71b880 100644 --- a/mlx/backend/metal/kernels/utils.h +++ b/mlx/backend/metal/kernels/utils.h @@ -64,6 +64,16 @@ struct Limits { static constexpr constant bool min = false; }; +template <> +struct Limits { + static constexpr constant complex64_t max = complex64_t( + metal::numeric_limits::infinity(), + metal::numeric_limits::infinity()); + static constexpr constant complex64_t min = complex64_t( + -metal::numeric_limits::infinity(), + 
-metal::numeric_limits<float>::infinity());
+};
+
 ///////////////////////////////////////////////////////////////////////////////
 // Indexing utils
 ///////////////////////////////////////////////////////////////////////////////
@@ -101,6 +111,34 @@ METAL_FUNC stride_t elem_to_loc(
   return loc;
 }
 
+template <typename stride_t>
+METAL_FUNC stride_t elem_to_loc(
+    stride_t elem,
+    device const int* shape,
+    device const stride_t* strides,
+    int ndim) {
+  stride_t loc = 0;
+  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
+    loc += (elem % shape[i]) * strides[i];
+    elem /= shape[i];
+  }
+  return loc;
+}
+
+template <typename stride_t>
+METAL_FUNC stride_t elem_to_loc(
+    stride_t elem,
+    constant const int* shape,
+    constant const stride_t* strides,
+    int ndim) {
+  stride_t loc = 0;
+  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
+    loc += (elem % shape[i]) * strides[i];
+    elem /= shape[i];
+  }
+  return loc;
+}
+
 // Non templated version to handle arbitrary dims
 template <typename stride_t>
 METAL_FUNC stride_t elem_to_loc(
@@ -288,12 +326,87 @@ METAL_FUNC uint3 elem_to_loc_3_nd(
   return loc;
 }
 
+///////////////////////////////////////////////////////////////////////////////
+// Elem to loc in a loop utils
+///////////////////////////////////////////////////////////////////////////////
+
+template <int dim, typename offset_t = size_t>
+struct looped_elem_to_loc {
+  looped_elem_to_loc<dim - 1, offset_t> inner_looper;
+  offset_t offset{0};
+  int index{0};
+
+  void next(const constant int* shape, const constant size_t* strides) {
+    index++;
+    offset += strides[dim - 1];
+
+    if (index >= shape[dim - 1]) {
+      index = 0;
+      inner_looper.next(shape, strides);
+      offset = inner_looper.offset;
+    }
+  }
+
+  void next(int n, const constant int* shape, const constant size_t* strides) {
+    index += n;
+    offset += n * strides[dim - 1];
+
+    if (index >= shape[dim - 1]) {
+      int extra = index - shape[dim - 1];
+      index = 0;
+      inner_looper.next(shape, strides);
+      offset = inner_looper.offset;
+      if (extra > 0) {
+        next(extra, shape, strides);
+      }
+    }
+  }
+
+  offset_t
+  location(offset_t, const constant int*, const constant size_t*, int) {
+    return offset;
+  }
+};
+
+template <typename offset_t>
+struct looped_elem_to_loc<1, offset_t> {
+  offset_t offset{0};
+
+  void next(const constant int*, const constant size_t* strides) {
+    offset += strides[0];
+  }
+
+  void next(int n, const constant int*, const constant size_t* strides) {
+    offset += n * strides[0];
+  }
+
+  offset_t
+  location(offset_t, const constant int*, const constant size_t*, int) {
+    return offset;
+  }
+};
+
+template <typename offset_t>
+struct looped_elem_to_loc<0, offset_t> {
+  void next(const constant int*, const constant size_t*) {}
+  void next(int, const constant int*, const constant size_t*) {}
+
+  offset_t location(
+      offset_t idx,
+      const constant int* shape,
+      const constant size_t* strides,
+      int ndim) {
+    return elem_to_loc(idx, shape, strides, ndim);
+  }
+};
+
 ///////////////////////////////////////////////////////////////////////////////
 // Calculation utils
 ///////////////////////////////////////////////////////////////////////////////
 
 /** Compute ceil((float)N/(float)M) */
-inline size_t ceildiv(size_t N, size_t M) {
+template <typename T, typename U>
+inline T ceildiv(T N, U M) {
   return (N + M - 1) / M;
 }
 
@@ -339,3 +452,8 @@ inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
 inline bool simd_shuffle_down(bool data, uint16_t delta) {
   return simd_shuffle_down(static_cast<uint16_t>(data), delta);
 }
+
+inline complex64_t simd_shuffle_down(complex64_t data, uint16_t delta) {
+  return complex64_t(
+      simd_shuffle_down(data.real, delta), simd_shuffle_down(data.imag, delta));
+}
diff --git a/mlx/backend/metal/normalization.cpp
b/mlx/backend/metal/normalization.cpp index 5f728cd33..40f433fa6 100644 --- a/mlx/backend/metal/normalization.cpp +++ b/mlx/backend/metal/normalization.cpp @@ -135,8 +135,8 @@ void RMSNormVJP::eval_gpu( auto axis_size = static_cast(x.shape().back()); int n_rows = x.data_size() / axis_size; - // Allocate a temporary to store the gradients for w and initialize the - // gradient accumulator to 0. + // Allocate the gradient accumulator gw and a temporary to store the + // gradients before they are accumulated. array gw_temp({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}); bool g_in_gw = false; if (!g_in_gx && g.is_donatable()) { @@ -146,11 +146,7 @@ void RMSNormVJP::eval_gpu( gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes())); } copies.push_back(gw_temp); - { - array zero(0, gw.dtype()); - copy_gpu(zero, gw, CopyType::Scalar, s); - copies.push_back(std::move(zero)); - } + gw.set_data(allocator::malloc_or_wait(gw.nbytes())); const int simd_size = 32; const int n_reads = RMS_N_READS; @@ -330,8 +326,8 @@ void LayerNormVJP::eval_gpu( auto axis_size = static_cast(x.shape().back()); int n_rows = x.data_size() / axis_size; - // Allocate a temporary to store the gradients for w and initialize the - // gradient accumulator to 0. + // Allocate a temporary to store the gradients for w and allocate the output + // gradient accumulators. array gw_temp({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}); bool g_in_gw = false; if (!g_in_gx && g.is_donatable()) { @@ -341,12 +337,8 @@ void LayerNormVJP::eval_gpu( gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes())); } copies.push_back(gw_temp); - { - array zero(0, gw.dtype()); - copy_gpu(zero, gw, CopyType::Scalar, s); - copy_gpu(zero, gb, CopyType::Scalar, s); - copies.push_back(std::move(zero)); - } + gw.set_data(allocator::malloc_or_wait(gw.nbytes())); + gb.set_data(allocator::malloc_or_wait(gb.nbytes())); // Finish with the gradient for b in case we had a b auto& compute_encoder = d.get_command_encoder(s.index); diff --git a/mlx/backend/metal/reduce.cpp b/mlx/backend/metal/reduce.cpp index 2d110a5d0..92eccf880 100644 --- a/mlx/backend/metal/reduce.cpp +++ b/mlx/backend/metal/reduce.cpp @@ -15,9 +15,168 @@ namespace mlx::core { -////////////////////////////////////////////////////////////////////// -// Case wise reduce dispatch -////////////////////////////////////////////////////////////////////// +namespace { + +struct RowReduceArgs { + // Input shape and strides not including the reduction axes + std::vector shape; + std::vector strides; + int ndim; + + // Input shape and strides for the reduction axes + std::vector reduce_shape; + std::vector reduce_strides; + int reduce_ndim; + + // The number of rows we are reducing. Namely prod(reduce_shape). + size_t non_row_reductions; + + // The size of the row. + size_t row_size; + + RowReduceArgs( + const array& in, + const ReductionPlan& plan, + const std::vector& axes) { + row_size = plan.shape.back(); + + reduce_shape = plan.shape; + reduce_strides = plan.strides; + reduce_shape.pop_back(); + reduce_strides.pop_back(); + reduce_ndim = reduce_shape.size(); + + non_row_reductions = 1; + for (auto s : reduce_shape) { + non_row_reductions *= s; + } + + std::tie(shape, strides) = shapes_without_reduction_axes(in, axes); + std::tie(shape, strides) = collapse_contiguous_dims(shape, strides); + ndim = shape.size(); + } + + void encode(CommandEncoder& compute_encoder) { + // Push 0s to avoid encoding empty vectors. 
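For reference, the RowReduceArgs constructor above boils down to the following host-side bookkeeping. This is a stripped-down sketch with plain vectors in place of the MLX `array`/`ReductionPlan` types, so only the member names are taken from the patch:

#include <cstddef>
#include <vector>

// Sketch only: mirrors how RowReduceArgs splits a reduction plan into the
// row being reduced and the remaining ("non-row") reduction axes.
// E.g. plan shape = {4, 16}, plan strides = {16, 1}:
//   row_size = 16, reduce_shape = {4}, non_row_reductions = 4.
struct RowArgsSketch {
  size_t row_size;
  std::vector<int> reduce_shape;
  std::vector<size_t> reduce_strides;
  size_t non_row_reductions;

  RowArgsSketch(std::vector<int> plan_shape, std::vector<size_t> plan_strides) {
    row_size = plan_shape.back(); // last (contiguous) axis is the row
    plan_shape.pop_back();
    plan_strides.pop_back();
    reduce_shape = plan_shape;    // everything else is looped over
    reduce_strides = plan_strides;
    non_row_reductions = 1;
    for (auto s : reduce_shape) {
      non_row_reductions *= s;    // number of rows per output element
    }
  }
};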
+ if (reduce_ndim == 0) { + reduce_shape.push_back(0); + reduce_strides.push_back(0); + } + if (ndim == 0) { + shape.push_back(0); + strides.push_back(0); + } + + compute_encoder->setBytes(&row_size, sizeof(size_t), 2); + compute_encoder->setBytes(&non_row_reductions, sizeof(size_t), 3); + compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4); + compute_encoder->setBytes( + strides.data(), strides.size() * sizeof(size_t), 5); + compute_encoder->setBytes(&ndim, sizeof(int), 6); + compute_encoder->setBytes( + reduce_shape.data(), reduce_shape.size() * sizeof(int), 7); + compute_encoder->setBytes( + reduce_strides.data(), reduce_strides.size() * sizeof(size_t), 8); + compute_encoder->setBytes(&reduce_ndim, sizeof(int), 9); + + if (reduce_ndim == 0) { + reduce_shape.pop_back(); + reduce_strides.pop_back(); + } + if (ndim == 0) { + shape.pop_back(); + strides.pop_back(); + } + } +}; + +struct ColReduceArgs { + // Input shape and strides not including the reduction axes + std::vector shape; + std::vector strides; + int ndim; + + // Input shape and strides for the reduction axes + std::vector reduce_shape; + std::vector reduce_strides; + int reduce_ndim; + + // The number of column reductions we are doing. Namely prod(reduce_shape). + size_t non_col_reductions; + + // The size of the contiguous column reduction. + size_t reduction_size; + size_t reduction_stride; + + ColReduceArgs( + const array& in, + const ReductionPlan& plan, + const std::vector& axes) { + reduction_size = plan.shape.back(); + reduction_stride = plan.strides.back(); + + reduce_shape = plan.shape; + reduce_strides = plan.strides; + reduce_shape.pop_back(); + reduce_strides.pop_back(); + reduce_ndim = reduce_shape.size(); + + non_col_reductions = 1; + for (auto s : reduce_shape) { + non_col_reductions *= s; + } + + // We 'll use a stride_back variable because strides.back() could be 0 but + // yet we may have removed the appropriate amount of elements. It is safe + // to compute the stride by multiplying shapes (while < reduction_stride) + // because it is a contiguous section. + size_t stride_back = 1; + std::tie(shape, strides) = shapes_without_reduction_axes(in, axes); + while (!shape.empty() && stride_back < reduction_stride) { + stride_back *= shape.back(); + shape.pop_back(); + strides.pop_back(); + } + std::tie(shape, strides) = collapse_contiguous_dims(shape, strides); + ndim = shape.size(); + } + + void encode(CommandEncoder& compute_encoder) { + // Push 0s to avoid encoding empty vectors. 
+ if (reduce_ndim == 0) { + reduce_shape.push_back(0); + reduce_strides.push_back(0); + } + if (ndim == 0) { + shape.push_back(0); + strides.push_back(0); + } + + compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2); + compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3); + compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4); + compute_encoder->setBytes( + strides.data(), strides.size() * sizeof(size_t), 5); + compute_encoder->setBytes(&ndim, sizeof(int), 6); + compute_encoder->setBytes( + reduce_shape.data(), reduce_shape.size() * sizeof(int), 7); + compute_encoder->setBytes( + reduce_strides.data(), reduce_strides.size() * sizeof(size_t), 8); + compute_encoder->setBytes(&reduce_ndim, sizeof(int), 9); + compute_encoder->setBytes(&non_col_reductions, sizeof(size_t), 10); + + if (reduce_ndim == 0) { + reduce_shape.pop_back(); + reduce_strides.pop_back(); + } + if (ndim == 0) { + shape.pop_back(); + strides.pop_back(); + } + } +}; + +} // namespace inline auto safe_div(size_t n, size_t m) { return m == 0 ? 0 : (n + m - 1) / m; @@ -31,96 +190,237 @@ inline bool is_64b_int(Dtype dtype) { return dtype == int64 || dtype == uint64; } -// All Reduce +inline bool is_64b_dtype(Dtype dtype) { + return dtype == int64 || dtype == uint64 || dtype == complex64; +} + +inline int threadgroup_size_from_row_size(int row_size) { + // 1 simdgroup per row smallish rows + if (row_size <= 512) { + return 32; + } + + // 2 simdgroups per row for medium rows + if (row_size <= 1024) { + return 64; + } + + // up to 32 simdgroups after that + int thread_group_size; + thread_group_size = (row_size + REDUCE_N_READS - 1) / REDUCE_N_READS; + thread_group_size = ((thread_group_size + 31) / 32) * 32; + thread_group_size = std::min(1024, thread_group_size); + return thread_group_size; +} + +inline auto output_grid_for_col_reduce( + const array& out, + const ColReduceArgs& args) { + auto out_shape = out.shape(); + auto out_strides = out.strides(); + while (!out_shape.empty() && out_strides.back() < args.reduction_stride) { + out_shape.pop_back(); + out_strides.pop_back(); + } + return get_2d_grid_dims(out_shape, out_strides); +} + +void init_reduce( + array& out, + const std::string& op_name, + CommandEncoder& compute_encoder, + metal::Device& d, + const Stream& s) { + auto kernel = + get_reduce_init_kernel(d, "i_reduce_" + op_name + type_to_name(out), out); + size_t nthreads = out.size(); + MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); + NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); + if (thread_group_size > nthreads) { + thread_group_size = nthreads; + } + MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); + compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_output_array(out, 0); + compute_encoder.dispatchThreads(grid_dims, group_dims); +} + void all_reduce_dispatch( const array& in, array& out, const std::string& op_name, CommandEncoder& compute_encoder, metal::Device& d, - const Stream& s) { - Dtype out_dtype = out.dtype(); - bool is_out_64b_int = is_64b_int(out_dtype); - std::string kernel_name = "all"; - if (is_out_64b_int) { - kernel_name += "NoAtomics"; - } - kernel_name += "_reduce_" + op_name + type_to_name(in); - auto kernel = get_reduce_kernel(d, kernel_name, op_name, in, out); - + const Stream& s, + std::vector& copies) { + // Set the kernel + std::ostringstream kname; + kname << "all_reduce_" << op_name << type_to_name(in); + auto kernel = get_reduce_kernel(d, kname.str(), op_name, in, out); 
compute_encoder->setComputePipelineState(kernel); - // We make sure each thread has enough to do by making it read in - // at least n_reads inputs - int n_reads = REDUCE_N_READS; size_t in_size = in.size(); - // mod_in_size gives us the groups of n_reads needed to go over the entire - // input - uint mod_in_size = (in_size + n_reads - 1) / n_reads; - NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); - thread_group_size = - mod_in_size > thread_group_size ? thread_group_size : mod_in_size; - uint simd_size = kernel->threadExecutionWidth(); - thread_group_size = - ((thread_group_size + simd_size - 1) / simd_size) * simd_size; + // Small array so dispatch a single threadgroup + if (in_size <= REDUCE_N_READS * 1024) { + int threadgroup_size = (in_size + REDUCE_N_READS - 1) / REDUCE_N_READS; + threadgroup_size = ((threadgroup_size + 31) / 32) * 32; + MTL::Size grid_dims(threadgroup_size, 1, 1); - // If the number of thread groups needed exceeds 1024, we reuse threads groups - uint n_thread_groups = safe_div(mod_in_size, thread_group_size); - n_thread_groups = std::min(n_thread_groups, 1024u); - uint nthreads = n_thread_groups * thread_group_size; - - MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); - MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); - - // Encode buffers and dispatch - if (is_out_64b_int == false || n_thread_groups == 1) { compute_encoder.set_input_array(in, 0); compute_encoder.set_output_array(out, 1); compute_encoder->setBytes(&in_size, sizeof(size_t), 2); - compute_encoder.dispatchThreads(grid_dims, group_dims); + compute_encoder->setBytes(&in_size, sizeof(size_t), 3); + compute_encoder.dispatchThreads(grid_dims, grid_dims); + } - } else { - // Allocate intermediate array to store partial reduction results - size_t intermediate_size = n_thread_groups; - array intermediate = - array({static_cast(intermediate_size)}, out_dtype, nullptr, {}); + // We need multiple threadgroups so we 'll do it in 2 passes. 
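A rough sketch of the 2-pass sizing that follows, as standalone C++ (REDUCE_N_READS = 16 is an assumption about the kernel defines; the real dispatch also builds the MTL::Size grids, which are omitted here):

#include <algorithm>
#include <cstddef>
#include <cstdio>

// Sketch of the 2-pass all-reduce sizing (illustrative only).
constexpr int kNReads = 16; // assumed value of REDUCE_N_READS

void plan_two_pass_all_reduce(size_t in_size, size_t in_nbytes) {
  // 1st pass: split the input into n_rows partial reductions.
  int n_rows = (in_nbytes <= (1 << 26)) ? 32 * kNReads : 1024 * kNReads;
  int threadgroup_2nd_pass = (in_nbytes <= (1 << 26)) ? 32 : 1024;

  size_t row_size = (in_size + n_rows - 1) / n_rows;
  size_t tg = std::min((row_size + kNReads - 1) / kNReads, size_t(1024));
  tg = ((tg + 31) / 32) * 32; // round up to a full simdgroup

  std::printf("pass 1: %d rows, %zu elems/row, %zu threads/row\n",
              n_rows, row_size, tg);
  std::printf("pass 2: 1 row of %d partials, %d threads\n",
              n_rows, threadgroup_2nd_pass);
}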
+ else { + int n_rows, threadgroup_2nd_pass; + // Less than 2**26 bytes + if (in.nbytes() <= (1 << 26)) { + n_rows = 32 * REDUCE_N_READS; + threadgroup_2nd_pass = 32; + } + + // Really large matrix so parallelize as much as possible + else { + n_rows = 1024 * REDUCE_N_READS; + threadgroup_2nd_pass = 1024; + } + + // Allocate an intermediate tensor to hold results if needed + array intermediate({n_rows}, out.dtype(), nullptr, {}); intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes())); - std::vector intermediates = {intermediate}; + copies.push_back(intermediate); - // First dispatch + // 1st pass + size_t row_size = (in_size + n_rows - 1) / n_rows; + int threadgroup_size = + std::min((row_size + REDUCE_N_READS - 1) / REDUCE_N_READS, 1024ul); + threadgroup_size = ((threadgroup_size + 31) / 32) * 32; + MTL::Size grid_dims(threadgroup_size, n_rows, 1); + MTL::Size group_dims(threadgroup_size, 1, 1); compute_encoder.set_input_array(in, 0); compute_encoder.set_output_array(intermediate, 1); compute_encoder->setBytes(&in_size, sizeof(size_t), 2); + compute_encoder->setBytes(&row_size, sizeof(size_t), 3); compute_encoder.dispatchThreads(grid_dims, group_dims); - // Second pass to reduce intermediate reduction results written to DRAM + // 2nd pass + compute_encoder->setComputePipelineState(kernel); + size_t intermediate_size = n_rows; + grid_dims = MTL::Size(threadgroup_2nd_pass, 1, 1); + group_dims = MTL::Size(threadgroup_2nd_pass, 1, 1); compute_encoder.set_input_array(intermediate, 0); compute_encoder.set_output_array(out, 1); compute_encoder->setBytes(&intermediate_size, sizeof(size_t), 2); - - mod_in_size = (intermediate_size + n_reads - 1) / n_reads; - - thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); - thread_group_size = - mod_in_size > thread_group_size ? thread_group_size : mod_in_size; - thread_group_size = - ((thread_group_size + simd_size - 1) / simd_size) * simd_size; - - // If the number of thread groups needed exceeds 1024, we reuse threads - // groups - nthreads = thread_group_size; - group_dims = MTL::Size(thread_group_size, 1, 1); - grid_dims = MTL::Size(nthreads, 1, 1); + compute_encoder->setBytes(&intermediate_size, sizeof(size_t), 3); compute_encoder.dispatchThreads(grid_dims, group_dims); - - d.get_command_buffer(s.index)->addCompletedHandler( - [intermediates](MTL::CommandBuffer*) mutable { - intermediates.clear(); - }); } } +void row_reduce_small( + const array& in, + array& out, + const std::string& op_name, + RowReduceArgs& args, + CommandEncoder& compute_encoder, + metal::Device& d, + const Stream& s) { + // Set the kernel + std::ostringstream kname; + int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0; + kname << "rowSmall" << n << "_reduce_" << op_name << type_to_name(in); + auto kernel = get_reduce_kernel(d, kname.str(), op_name, in, out); + compute_encoder->setComputePipelineState(kernel); + + // Figure out the grid dims + MTL::Size grid_dims; + MTL::Size group_dims; + if ((args.non_row_reductions < 32 && args.row_size <= 8) || + args.non_row_reductions <= 8) { + grid_dims = get_2d_grid_dims(out.shape(), out.strides()); + group_dims = + MTL::Size((grid_dims.width < 1024) ? 
grid_dims.width : 1024, 1, 1); + } else { + auto out_grid_size = get_2d_grid_dims(out.shape(), out.strides()); + grid_dims = MTL::Size(32, out_grid_size.width, out_grid_size.height); + group_dims = MTL::Size(32, 1, 1); + } + + // Launch + compute_encoder.set_input_array(in, 0); + compute_encoder.set_output_array(out, 1); + args.encode(compute_encoder); + compute_encoder.dispatchThreads(grid_dims, group_dims); +} + +void row_reduce_simple( + const array& in, + array& out, + const std::string& op_name, + RowReduceArgs& args, + CommandEncoder& compute_encoder, + metal::Device& d, + const Stream& s) { + // Set the kernel + std::ostringstream kname; + kname << "rowSimple_reduce_" << op_name << type_to_name(in); + auto kernel = get_reduce_kernel(d, kname.str(), op_name, in, out); + compute_encoder->setComputePipelineState(kernel); + + // Figure out the grid dims + size_t row_size = args.row_size; + size_t out_size = out.size(); + auto out_grid_size = get_2d_grid_dims(out.shape(), out.strides()); + out_grid_size.width = + (out_grid_size.width + REDUCE_N_WRITES - 1) / REDUCE_N_WRITES; + int threadgroup_size = threadgroup_size_from_row_size(row_size); + if (in.itemsize() == 8) { + threadgroup_size = std::min(threadgroup_size, 512); + } + MTL::Size grid_dims( + threadgroup_size, out_grid_size.width, out_grid_size.height); + MTL::Size group_dims(threadgroup_size, 1, 1); + + // Launch + compute_encoder.set_input_array(in, 0); + compute_encoder.set_output_array(out, 1); + compute_encoder->setBytes(&row_size, sizeof(size_t), 2); + compute_encoder->setBytes(&out_size, sizeof(size_t), 3); + compute_encoder.dispatchThreads(grid_dims, group_dims); +} + +void row_reduce_looped( + const array& in, + array& out, + const std::string& op_name, + RowReduceArgs& args, + CommandEncoder& compute_encoder, + metal::Device& d, + const Stream& s) { + // Set the kernel + std::ostringstream kname; + int n = (args.reduce_ndim < 5) ? 
std::max(1, args.reduce_ndim) : 0; + kname << "rowLooped" << n << "_reduce_" << op_name << type_to_name(in); + auto kernel = get_reduce_kernel(d, kname.str(), op_name, in, out); + compute_encoder->setComputePipelineState(kernel); + + // Figure out the grid + auto out_grid_size = get_2d_grid_dims(out.shape(), out.strides()); + int threadgroup_size = threadgroup_size_from_row_size(args.row_size); + MTL::Size grid_dims( + threadgroup_size, out_grid_size.width, out_grid_size.height); + MTL::Size group_dims(threadgroup_size, 1, 1); + + // Launch + compute_encoder.set_input_array(in, 0); + compute_encoder.set_output_array(out, 1); + args.encode(compute_encoder); + compute_encoder.dispatchThreads(grid_dims, group_dims); +} + void row_reduce_general_dispatch( const array& in, array& out, @@ -130,175 +430,106 @@ void row_reduce_general_dispatch( CommandEncoder& compute_encoder, metal::Device& d, const Stream& s) { - Dtype out_dtype = out.dtype(); - bool is_out_64b_int = is_64b_int(out_dtype); - // Prepare the arguments for the kernel - size_t reduction_size = plan.shape.back(); - auto shape = plan.shape; - auto strides = plan.strides; + RowReduceArgs args(in, plan, axes); - shape.pop_back(); - strides.pop_back(); + // Case 1: The row is small + if (args.row_size <= 64) { + return row_reduce_small(in, out, op_name, args, compute_encoder, d, s); + } - size_t non_row_reductions = 1; - for (auto s : shape) { - non_row_reductions *= static_cast(s); + // Case 2: Contiguous reduce without non-row reductions + if (plan.type == ContiguousReduce && args.reduce_ndim == 0) { + return row_reduce_simple(in, out, op_name, args, compute_encoder, d, s); } - size_t out_size = out.size(); - auto [rem_shape, rem_strides] = shapes_without_reduction_axes(in, axes); - for (auto s : rem_shape) { - shape.push_back(s); - } - for (auto s : rem_strides) { - strides.push_back(s); - } - int ndim = shape.size(); - // Determine dispatch kernel + // Case 3: General row reduce including non-row reductions + return row_reduce_looped(in, out, op_name, args, compute_encoder, d, s); +} + +void strided_reduce_small( + const array& in, + array& out, + const std::string& op_name, + ColReduceArgs& args, + CommandEncoder& compute_encoder, + metal::Device& d, + const Stream& s) { + // Figure out the grid dims + MTL::Size grid_dims, group_dims; + + // Case 1: everything is small so launch one thread per col reduce + if (args.reduction_size * args.non_col_reductions < 64) { + grid_dims = output_grid_for_col_reduce(out, args); + int threadgroup_size = (grid_dims.width > 128) ? 128 : grid_dims.width; + group_dims = MTL::Size(threadgroup_size, 1, 1); + } + + // Case 2: Reduction in the simdgroup + else { + args.reduce_shape.push_back(args.reduction_size); + args.reduce_strides.push_back(args.reduction_stride); + args.reduce_ndim++; + int simdgroups = + (args.reduction_stride + REDUCE_N_READS - 1) / REDUCE_N_READS; + int threadgroup_size = simdgroups * 32; + auto out_grid_dims = output_grid_for_col_reduce(out, args); + grid_dims = + MTL::Size(threadgroup_size, out_grid_dims.width, out_grid_dims.height); + group_dims = MTL::Size(threadgroup_size, 1, 1); + } + + // Set the kernel + int n = (args.reduce_ndim < 5) ? 
std::max(1, args.reduce_ndim) : 0; std::ostringstream kname; - - bool is_small = non_row_reductions * reduction_size < 32; - bool is_med = non_row_reductions * reduction_size <= 256; - is_out_64b_int &= !is_small && !is_med; - - std::string small_desc; - if (is_out_64b_int) { - small_desc = "NoAtomics"; - } else if (is_small) { - small_desc = "Small"; - } else if (is_med) { - small_desc = "Med"; - } else { - small_desc = ""; - } - kname << "rowGeneral" << small_desc << "_reduce_" << op_name - << type_to_name(in); - + kname << "colSmall" << n << "_reduce_" << op_name << type_to_name(in); auto kernel = get_reduce_kernel(d, kname.str(), op_name, in, out); compute_encoder->setComputePipelineState(kernel); - // Get dispatch grid dims - MTL::Size grid_dims; - MTL::Size group_dims; + // Launch + compute_encoder.set_input_array(in, 0); + compute_encoder.set_output_array(out, 1); + args.encode(compute_encoder); + compute_encoder.dispatchThreads(grid_dims, group_dims); +} - // Each thread handles one output - if (is_small) { - grid_dims = MTL::Size(out.size(), 1, 1); - group_dims = MTL::Size(std::min(1024ul, out.size()), 1, 1); - } - // Each simdgroup handles one output - else if (is_med) { - grid_dims = MTL::Size(out.size() * 32, 1, 1); - group_dims = MTL::Size(std::min(8ul, out.size()) * 32, 1, 1); - } - // Each theadgroup handles one output - else { - int n_reads = REDUCE_N_READS; - NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); - thread_group_size = - std::min((reduction_size + n_reads - 1) / n_reads, thread_group_size); +void strided_reduce_looped( + const array& in, + array& out, + const std::string& op_name, + ColReduceArgs& args, + CommandEncoder& compute_encoder, + metal::Device& d, + const Stream& s) { + // Prepare the arguments for the kernel + args.reduce_shape.push_back(args.reduction_size); + args.reduce_strides.push_back(args.reduction_stride); + args.reduce_ndim++; - // Align thread group size with simd_size - uint simd_size = kernel->threadExecutionWidth(); - thread_group_size = - (thread_group_size + simd_size - 1) / simd_size * simd_size; - assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup()); + // Figure out the grid dims + auto out_grid_size = output_grid_for_col_reduce(out, args); + int BN = (args.reduction_stride <= 256) ? 32 : 128; + int BM = 1024 / BN; + int threadgroup_size = 4 * 32; + MTL::Size grid_dims( + threadgroup_size * ((args.reduction_stride + BN - 1) / BN), + out_grid_size.width, + out_grid_size.height); + MTL::Size group_dims(threadgroup_size, 1, 1); - // Launch enough thread groups for each output - size_t n_threads = out.size() * thread_group_size; - grid_dims = MTL::Size(n_threads, non_row_reductions, 1); - group_dims = MTL::Size(thread_group_size, 1, 1); - } + // Set the kernel + int n = (args.reduce_ndim < 5) ? 
std::max(1, args.reduce_ndim) : 0; + std::ostringstream kname; + kname << "colLooped" << n << "_" << BM << "_" << BN << "_reduce_" << op_name + << type_to_name(in); + auto kernel = get_reduce_kernel(d, kname.str(), op_name, in, out); + compute_encoder->setComputePipelineState(kernel); - // Dispatch kernel - if (!is_out_64b_int || non_row_reductions == 1) { - // Set the arguments for the kernel - compute_encoder.set_input_array(in, 0); - compute_encoder.set_output_array(out, 1); - compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2); - compute_encoder->setBytes(&out_size, sizeof(size_t), 3); - compute_encoder->setBytes(&non_row_reductions, sizeof(size_t), 4); - compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5); - compute_encoder->setBytes( - strides.data(), strides.size() * sizeof(size_t), 6); - compute_encoder->setBytes(&ndim, sizeof(int), 7); - compute_encoder.dispatchThreads(grid_dims, group_dims); - - } else { - // Allocate intermediate array to store partial reduction results - array intermediate = array( - {static_cast(out.size()), static_cast(non_row_reductions)}, - out_dtype, - nullptr, - {}); - intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes())); - std::vector intermediates = {intermediate}; - - // Set the arguments for the kernel - compute_encoder.set_input_array(in, 0); - compute_encoder.set_output_array(intermediate, 1); - compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2); - compute_encoder->setBytes(&out_size, sizeof(size_t), 3); - compute_encoder->setBytes(&non_row_reductions, sizeof(size_t), 4); - compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5); - compute_encoder->setBytes( - strides.data(), strides.size() * sizeof(size_t), 6); - compute_encoder->setBytes(&ndim, sizeof(int), 7); - compute_encoder.dispatchThreads(grid_dims, group_dims); - - // Set up second dispatch - reduction_size = non_row_reductions; - out_size = 1; - - // Shape of axes that aren't participating in reduction remains unchanged. - std::vector new_shape = rem_shape; - - // Update their strides since they'll be different post partial reduction in - // first compute dispatch. 
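The tile and grid selection in strided_reduce_looped above can be summarized with this small standalone sketch (plain ints in place of MTL::Size; illustrative only):

#include <cstddef>
#include <cstdio>

// Sketch of the tile/grid selection in strided_reduce_looped.
void plan_col_reduce_looped(size_t reduction_stride,
                            size_t out_grid_w, size_t out_grid_h) {
  int BN = (reduction_stride <= 256) ? 32 : 128; // columns per tile
  int BM = 1024 / BN;                            // rows per tile
  int threadgroup_size = 4 * 32;                 // 4 simdgroups
  size_t grid_x = threadgroup_size * ((reduction_stride + BN - 1) / BN);
  std::printf("BM=%d BN=%d grid=(%zu, %zu, %zu) group=(%d, 1, 1)\n",
              BM, BN, grid_x, out_grid_w, out_grid_h, threadgroup_size);
}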
- std::vector new_strides = rem_strides; - new_strides.back() = reduction_size; - for (int i = new_shape.size() - 2; i >= 0; i--) { - new_strides[i] = new_shape[i + 1] * new_strides[i + 1]; - } - ndim = new_shape.size(); - - // Set the arguments for the kernel - compute_encoder.set_input_array(intermediate, 0); - compute_encoder.set_output_array(out, 1); - compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2); - compute_encoder->setBytes(&out_size, sizeof(size_t), 3); - compute_encoder->setBytes(&non_row_reductions, sizeof(size_t), 4); - compute_encoder->setBytes( - new_shape.data(), new_shape.size() * sizeof(int), 5); - compute_encoder->setBytes( - new_strides.data(), new_strides.size() * sizeof(size_t), 6); - compute_encoder->setBytes(&ndim, sizeof(int), 7); - - // Each thread group is responsible for 1 output - int n_reads = REDUCE_N_READS; - size_t thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); - thread_group_size = - std::min((reduction_size + n_reads - 1) / n_reads, thread_group_size); - - // Align thread group size with simd_size - uint simd_size = kernel->threadExecutionWidth(); - thread_group_size = - (thread_group_size + simd_size - 1) / simd_size * simd_size; - assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup()); - - // Launch enough thread groups for each output - size_t n_threads = thread_group_size; - grid_dims = MTL::Size(n_threads, out.size(), 1); - group_dims = MTL::Size(thread_group_size, 1, 1); - - compute_encoder.dispatchThreads(grid_dims, group_dims); - - d.get_command_buffer(s.index)->addCompletedHandler( - [intermediates](MTL::CommandBuffer*) mutable { - intermediates.clear(); - }); - } + // Launch + compute_encoder.set_input_array(in, 0); + compute_encoder.set_output_array(out, 1); + args.encode(compute_encoder); + compute_encoder.dispatchThreads(grid_dims, group_dims); } void strided_reduce_general_dispatch( @@ -310,248 +541,23 @@ void strided_reduce_general_dispatch( CommandEncoder& compute_encoder, metal::Device& d, const Stream& s) { - Dtype out_dtype = out.dtype(); - // Prepare the arguments for the kernel - size_t reduction_size = plan.shape.back(); - size_t reduction_stride = plan.strides.back(); - size_t out_size = out.size(); - auto shape = plan.shape; - auto strides = plan.strides; - shape.pop_back(); - strides.pop_back(); - size_t non_col_reductions = 1; - for (auto s : shape) { - non_col_reductions *= static_cast(s); + ColReduceArgs args(in, plan, axes); + + if (args.reduction_stride < 32) { + return strided_reduce_small(in, out, op_name, args, compute_encoder, d, s); } - std::vector non_col_shapes = shape; - std::vector non_col_strides = strides; - int non_col_ndim = shape.size(); - - auto [rem_shape, rem_strides] = shapes_without_reduction_axes(in, axes); - for (auto s : rem_shape) { - shape.push_back(s); - } - for (auto s : rem_strides) { - strides.push_back(s); - } - int ndim = shape.size(); - - // Specialize for small dims - if (reduction_size * non_col_reductions < 16) { - // Select kernel - auto kernel = get_reduce_kernel( - d, "colSmall_reduce_" + op_name + type_to_name(in), op_name, in, out); - compute_encoder->setComputePipelineState(kernel); - - // Select block dims - MTL::Size grid_dims = MTL::Size(out_size, 1, 1); - MTL::Size group_dims = MTL::Size(256ul, 1, 1); - - if (non_col_ndim == 0) { - non_col_shapes = {1}; - non_col_strides = {1}; - } - - // Encode arrays - compute_encoder.set_input_array(in, 0); - compute_encoder.set_output_array(out, 1); - compute_encoder->setBytes(&reduction_size, 
sizeof(size_t), 2); - compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3); - compute_encoder->setBytes(&out_size, sizeof(size_t), 4); - compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5); - compute_encoder->setBytes( - strides.data(), strides.size() * sizeof(size_t), 6); - compute_encoder->setBytes(&ndim, sizeof(int), 7); - compute_encoder->setBytes(&non_col_reductions, sizeof(size_t), 8); - compute_encoder->setBytes( - non_col_shapes.data(), non_col_shapes.size() * sizeof(int), 9); - compute_encoder->setBytes( - non_col_strides.data(), non_col_shapes.size() * sizeof(size_t), 10); - compute_encoder->setBytes(&non_col_ndim, sizeof(int), 11); - - // Dispatch threads - compute_encoder.dispatchThreads(grid_dims, group_dims); - - return; - } - - // Select kernel - bool is_out_64b_int = is_64b_int(out_dtype); - std::string kernel_name = "colGeneral"; - if (is_out_64b_int) { - kernel_name += "NoAtomics"; - } - kernel_name += "_reduce_" + op_name + type_to_name(in); - auto kernel = get_reduce_kernel(d, kernel_name, op_name, in, out); - - compute_encoder->setComputePipelineState(kernel); - - // Select block dimensions - // Each thread reads 16 inputs to give it more work - uint n_inputs_per_thread = REDUCE_N_READS; - uint n_threads_per_output = - (reduction_size + n_inputs_per_thread - 1) / n_inputs_per_thread; - - // We spread outputs over the x dimension and inputs over the y dimension - // Threads with the same lid.x in a given threadgroup work on the same - // output and each thread in the y dimension accumulates for that output - - // Threads with same lid.x, i.e. each column of threads work on same output - uint threadgroup_dim_x = std::min(out_size, 128ul); - - // Number of threads along y, is dependent on number of reductions needed. - uint threadgroup_dim_y = - kernel->maxTotalThreadsPerThreadgroup() / threadgroup_dim_x; - threadgroup_dim_y = std::min(n_threads_per_output, threadgroup_dim_y); - - // Derive number of thread groups along x, based on how many threads we need - // along x - uint n_threadgroups_x = - (out_size + threadgroup_dim_x - 1) / threadgroup_dim_x; - - // Derive number of thread groups along y based on how many threads we need - // along y - uint n_threadgroups_y = - (n_threads_per_output + threadgroup_dim_y - 1) / threadgroup_dim_y; - - // Launch enough thread groups for each output - MTL::Size grid_dims = - MTL::Size(n_threadgroups_x, n_threadgroups_y, non_col_reductions); - MTL::Size group_dims = MTL::Size(threadgroup_dim_x, threadgroup_dim_y, 1); - - if (is_out_64b_int == false) { - // Set the arguments for the kernel - compute_encoder.set_input_array(in, 0); - compute_encoder.set_output_array(out, 1); - compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2); - compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3); - compute_encoder->setBytes(&out_size, sizeof(size_t), 4); - compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5); - compute_encoder->setBytes( - strides.data(), strides.size() * sizeof(size_t), 6); - compute_encoder->setBytes(&ndim, sizeof(int), 7); - - // We set shared memory to be exploited here for reductions within a - // threadgroup - each thread must be able to update its accumulated output - // Note: Each threadgroup should have 32kB of data in threadgroup memory - // and threadgroup_dim_x * threadgroup_dim_y <= 1024 by design - // This should be fine for floats, but we might need to revisit - // if we ever come to doubles. 
In that case, we should also cut - // down the number of threads we launch in a threadgroup - compute_encoder->setThreadgroupMemoryLength( - safe_divup(threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 16), - 0); - compute_encoder.dispatchThreadgroups(grid_dims, group_dims); - - } else { - // Allocate intermediate array to store reduction results from all thread - // groups - array intermediate = array( - {static_cast(out.size()), - static_cast(n_threadgroups_y * non_col_reductions)}, - out_dtype, - nullptr, - {}); - intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes())); - std::vector intermediates = {intermediate}; - - // Set the arguments for the kernel - compute_encoder.set_input_array(in, 0); - compute_encoder.set_output_array(intermediate, 1); - compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2); - compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3); - compute_encoder->setBytes(&out_size, sizeof(size_t), 4); - compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5); - compute_encoder->setBytes( - strides.data(), strides.size() * sizeof(size_t), 6); - compute_encoder->setBytes(&ndim, sizeof(int), 7); - - // We set shared memory to be exploited here for reductions within a - // threadgroup - each thread must be able to update its accumulated output - // Note: Each threadgroup should have 32kB of data in threadgroup memory - // and threadgroup_dim_x * threadgroup_dim_y <= 1024 by design - // This should be fine for floats, but we might need to revisit - // if we ever come to doubles. In that case, we should also cut - // down the number of threads we launch in a threadgroup - compute_encoder->setThreadgroupMemoryLength( - safe_divup(threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 16), - 0); - compute_encoder.dispatchThreadgroups(grid_dims, group_dims); - - // Perform second pass of reductions - // Reduce results of threadgroups along y, z from first pass, that - // collectively work on each output element. - reduction_size = n_threadgroups_y * non_col_reductions; - out_size = 1; - - // Shape of axes that aren't participating in reduction remains unchanged. - std::vector new_shape = rem_shape; - - // Update their strides since they'll be different after a partial reduction - // post first compute dispatch. 
- std::vector new_strides = rem_strides; - new_strides.back() = reduction_size; - for (int i = new_shape.size() - 2; i >= 0; i--) { - new_strides[i] = new_shape[i + 1] * new_strides[i + 1]; - } - ndim = new_shape.size(); - - std::string kernel_name = - "rowGeneralNoAtomics_reduce_" + op_name + type_to_name(intermediate); - auto row_reduce_kernel = - get_reduce_kernel(d, kernel_name, op_name, intermediate, out); - - compute_encoder->setComputePipelineState(row_reduce_kernel); - compute_encoder.set_input_array(intermediate, 0); - compute_encoder.set_output_array(out, 1); - compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2); - compute_encoder->setBytes(&out_size, sizeof(size_t), 3); - compute_encoder->setBytes(&reduction_size, sizeof(size_t), 4); - compute_encoder->setBytes( - new_shape.data(), new_shape.size() * sizeof(int), 5); - compute_encoder->setBytes( - new_strides.data(), new_strides.size() * sizeof(size_t), 6); - compute_encoder->setBytes(&ndim, sizeof(int), 7); - - // Each thread group is responsible for 1 output - size_t n_reads = REDUCE_N_READS; - size_t thread_group_size = - row_reduce_kernel->maxTotalThreadsPerThreadgroup(); - thread_group_size = - std::min((reduction_size + n_reads - 1) / n_reads, thread_group_size); - - // Align thread group size with simd_size - uint simd_size = row_reduce_kernel->threadExecutionWidth(); - thread_group_size = - (thread_group_size + simd_size - 1) / simd_size * simd_size; - assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup()); - - // Launch enough thread groups for each output - uint n_threads = thread_group_size; - grid_dims = MTL::Size(n_threads, out.size(), 1); - group_dims = MTL::Size(thread_group_size, 1, 1); - - compute_encoder.dispatchThreads(grid_dims, group_dims); - - d.get_command_buffer(s.index)->addCompletedHandler( - [intermediates](MTL::CommandBuffer*) mutable { - intermediates.clear(); - }); - } + return strided_reduce_looped(in, out, op_name, args, compute_encoder, d, s); } -////////////////////////////////////////////////////////////////////// -// Main reduce dispatch -////////////////////////////////////////////////////////////////////// - void Reduce::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); array in = inputs[0]; // Make sure no identity reductions trickle down here assert(!axes_.empty()); + assert(out.size() != in.size()); // Continue with reduction operation // Minimum of 4 bytes since we use size 4 structs for all reduce @@ -584,20 +590,6 @@ void Reduce::eval_gpu(const std::vector& inputs, array& out) { auto& s = stream(); auto& d = metal::device(s.device); auto& compute_encoder = d.get_command_encoder(s.index); - { - auto kernel = get_reduce_init_kernel( - d, "i_reduce_" + op_name + type_to_name(out), out); - size_t nthreads = out.size(); - MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); - NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); - if (thread_group_size > nthreads) { - thread_group_size = nthreads; - } - MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); - compute_encoder->setComputePipelineState(kernel); - compute_encoder.set_output_array(out, 0); - compute_encoder.dispatchThreads(grid_dims, group_dims); - } // Reduce if (in.size() > 0) { @@ -606,6 +598,10 @@ void Reduce::eval_gpu(const std::vector& inputs, array& out) { // If it is a general reduce then copy the input to a contiguous array and // recompute the plan. 
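Condensing the dispatch logic in this file, the kernel choice reduces to roughly the following decision tree (a sketch using the thresholds that appear above, not the actual dispatch code):

#include <cstddef>

// Sketch of the kernel selection condensed from row_reduce_general_dispatch,
// strided_reduce_general_dispatch and Reduce::eval_gpu (illustrative only).
enum class ReduceKernel { Init, AllReduce, RowSmall, RowSimple, RowLooped,
                          ColSmall, ColLooped };

ReduceKernel pick_kernel(bool all_reduce, bool row_contiguous,
                         bool contiguous_reduce, size_t row_size,
                         int reduce_ndim, size_t reduction_stride) {
  if (all_reduce) {
    return ReduceKernel::AllReduce; // 1 or 2 pass all_reduce_dispatch
  }
  if (row_contiguous) {
    if (row_size <= 64) {
      return ReduceKernel::RowSmall; // per-thread or per-simdgroup outputs
    }
    if (contiguous_reduce && reduce_ndim == 0) {
      return ReduceKernel::RowSimple; // one threadgroup per N_WRITES rows
    }
    return ReduceKernel::RowLooped; // loop over the non-row reductions
  }
  // strided (column) reductions
  return (reduction_stride < 32) ? ReduceKernel::ColSmall
                                 : ReduceKernel::ColLooped;
}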
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 9ffa693ad..e77d18b0f 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -16,9 +16,11 @@ namespace mlx::core {
 
 namespace {
 
-std::pair<std::vector<int>, std::vector<int>> compute_reduce_shape(
+std::tuple<std::vector<int>, std::vector<int>, std::vector<int>, bool>
+compute_reduce_shape(
     const std::vector<int>& axes,
     const std::vector<int>& shape) {
+  bool is_noop = true;
   std::set<int> axes_set;
   auto ndim = shape.size();
   for (auto ax : axes) {
@@ -35,15 +37,18 @@ std::pair<std::vector<int>, std::vector<int>> compute_reduce_shape(
     throw std::invalid_argument("Duplicate axes detected in reduction.");
   }
   std::vector<int> out_shape;
+  std::vector<int> squeezed_shape;
   for (int i = 0; i < ndim; ++i) {
     if (axes_set.count(i) == 0) {
       out_shape.push_back(shape[i]);
+      squeezed_shape.push_back(shape[i]);
     } else {
       out_shape.push_back(1);
     }
+    is_noop &= (out_shape.back() == shape[i]);
   }
   std::vector<int> sorted_axes(axes_set.begin(), axes_set.end());
-  return {out_shape, sorted_axes};
+  return {out_shape, sorted_axes, squeezed_shape, is_noop};
 }
 
 Dtype at_least_float(const Dtype& d) {
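One detail worth noting in the hunks above: with `keepdims == false` the result is now reshaped to the precomputed `squeezed_shape`, whose job is to drop exactly the reduced axes while leaving any pre-existing size-1 axes of the input alone. A small self-contained sketch of that shape computation (the helper name is hypothetical, not MLX's):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: drop exactly the reduced axes from a keepdims-style
// output shape, keeping pre-existing singleton axes in place. This is what
// reshaping to squeezed_shape accomplishes in the hunks above.
std::vector<int> drop_reduced_axes(
    const std::vector<int>& out_shape,
    const std::vector<int>& sorted_axes) {
  std::vector<int> result;
  std::size_t j = 0;
  for (int i = 0; i < static_cast<int>(out_shape.size()); ++i) {
    if (j < sorted_axes.size() && sorted_axes[j] == i) {
      ++j; // reduced axis: skip it
    } else {
      result.push_back(out_shape[i]); // non-reduced axis: keep it, even if 1
    }
  }
  return result;
}

int main() {
  // A (4, 1, 5) array reduced over axis 2 with keepdims=false:
  // out_shape is (4, 1, 1); only axis 2 is dropped, and the original
  // size-1 axis at position 1 stays.
  std::vector<int> out_shape = {4, 1, 1};
  std::vector<int> sorted_axes = {2};
  auto squeezed = drop_reduced_axes(out_shape, sorted_axes);
  assert((squeezed == std::vector<int>{4, 1}));
}
```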
@@ -1502,17 +1507,17 @@ array all(
     const std::vector<int>& axes,
     bool keepdims /* = false */,
     StreamOrDevice s /* = {}*/) {
-  if (axes.empty()) {
-    return astype(a, bool_, s);
-  }
-  auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
-  auto out = array(
-      out_shape,
-      bool_,
-      std::make_shared<Reduce>(to_stream(s), Reduce::And, sorted_axes),
-      {a});
+  auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
+      compute_reduce_shape(axes, a.shape());
+  auto out = (is_noop)
+      ? astype(a, bool_, s)
+      : array(
+            std::move(out_shape),
+            bool_,
+            std::make_shared<Reduce>(to_stream(s), Reduce::And, sorted_axes),
+            {a});
   if (!keepdims) {
-    out = squeeze(out, sorted_axes, s);
+    out = reshape(out, std::move(squeezed_shape), s);
   }
   return out;
 }
@@ -1536,17 +1541,17 @@ array any(
     const std::vector<int>& axes,
     bool keepdims /* = false */,
     StreamOrDevice s /* = {}*/) {
-  if (axes.empty()) {
-    return astype(a, bool_, s);
-  }
-  auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
-  auto out = array(
-      out_shape,
-      bool_,
-      std::make_shared<Reduce>(to_stream(s), Reduce::Or, sorted_axes),
-      {a});
+  auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
+      compute_reduce_shape(axes, a.shape());
+  auto out = (is_noop)
+      ? astype(a, bool_, s)
+      : array(
+            std::move(out_shape),
+            bool_,
+            std::make_shared<Reduce>(to_stream(s), Reduce::Or, sorted_axes),
+            {a});
   if (!keepdims) {
-    out = squeeze(out, sorted_axes, s);
+    out = reshape(out, std::move(squeezed_shape), s);
   }
   return out;
 }
@@ -1573,15 +1578,18 @@ array sum(
   if (axes.empty()) {
     return a;
   }
-  auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
+  auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
+      compute_reduce_shape(axes, a.shape());
   auto out_type = a.dtype() == bool_ ? int32 : a.dtype();
-  auto out = array(
-      out_shape,
-      out_type,
-      std::make_shared<Reduce>(to_stream(s), Reduce::Sum, sorted_axes),
-      {a});
+  auto out = (is_noop)
+      ? astype(a, out_type, s)
+      : array(
+            std::move(out_shape),
+            out_type,
+            std::make_shared<Reduce>(to_stream(s), Reduce::Sum, sorted_axes),
+            {a});
   if (!keepdims) {
-    out = squeeze(out, sorted_axes, s);
+    out = reshape(out, std::move(squeezed_shape), s);
   }
   return out;
 }
@@ -1715,14 +1723,17 @@ array prod(
   if (axes.empty()) {
     return a;
   }
-  auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
-  auto out = array(
-      out_shape,
-      a.dtype(),
-      std::make_shared<Reduce>(to_stream(s), Reduce::Prod, sorted_axes),
-      {a});
+  auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
+      compute_reduce_shape(axes, a.shape());
+  auto out = (is_noop)
+      ? a
+      : array(
+            std::move(out_shape),
+            a.dtype(),
+            std::make_shared<Reduce>(to_stream(s), Reduce::Prod, sorted_axes),
+            {a});
   if (!keepdims) {
-    out = squeeze(out, sorted_axes, s);
+    out = reshape(out, std::move(squeezed_shape), s);
   }
   return out;
 }
@@ -1749,17 +1760,17 @@ array max(
   if (a.size() == 0) {
     throw std::invalid_argument("[max] Cannot max reduce zero size array.");
   }
-  if (axes.empty()) {
-    return a;
-  }
-  auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
-  auto out = array(
-      out_shape,
-      a.dtype(),
-      std::make_shared<Reduce>(to_stream(s), Reduce::Max, sorted_axes),
-      {a});
+  auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
+      compute_reduce_shape(axes, a.shape());
+  auto out = (is_noop)
+      ? a
+      : array(
+            std::move(out_shape),
+            a.dtype(),
+            std::make_shared<Reduce>(to_stream(s), Reduce::Max, sorted_axes),
+            {a});
   if (!keepdims) {
-    out = squeeze(out, sorted_axes, s);
+    out = reshape(out, std::move(squeezed_shape), s);
   }
   return out;
 }
@@ -1789,14 +1800,17 @@ array min(
   if (axes.empty()) {
     return a;
   }
-  auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
-  auto out = array(
-      out_shape,
-      a.dtype(),
-      std::make_shared<Reduce>(to_stream(s), Reduce::Min, sorted_axes),
-      {a});
+  auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
+      compute_reduce_shape(axes, a.shape());
+  auto out = (is_noop)
+      ? a
+      : array(
+            std::move(out_shape),
+            a.dtype(),
+            std::make_shared<Reduce>(to_stream(s), Reduce::Min, sorted_axes),
+            {a});
   if (!keepdims) {
-    out = squeeze(out, sorted_axes, s);
+    out = reshape(out, std::move(squeezed_shape), s);
   }
   return out;
 }
@@ -1829,15 +1843,18 @@ array argmin(
     throw std::invalid_argument(
         "[argmin] Cannot argmin reduce zero size array.");
   }
-  auto [out_shape, sorted_axes] = compute_reduce_shape({axis}, a.shape());
-  auto out = array(
-      out_shape,
-      uint32,
-      std::make_shared<ArgReduce>(
-          to_stream(s), ArgReduce::ArgMin, sorted_axes[0]),
-      {a});
+  auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
+      compute_reduce_shape({axis}, a.shape());
+  auto out = (is_noop)
+      ? zeros(out_shape, uint32, s)
+      : array(
+            std::move(out_shape),
+            uint32,
+            std::make_shared<ArgReduce>(
+                to_stream(s), ArgReduce::ArgMin, sorted_axes[0]),
+            {a});
   if (!keepdims) {
-    out = squeeze(out, sorted_axes, s);
+    out = reshape(out, std::move(squeezed_shape), s);
   }
   return out;
 }
@@ -1862,15 +1879,18 @@ array argmax(
     throw std::invalid_argument(
         "[argmax] Cannot argmax reduce zero size array.");
   }
-  auto [out_shape, sorted_axes] = compute_reduce_shape({axis}, a.shape());
-  auto out = array(
-      out_shape,
-      uint32,
-      std::make_shared<ArgReduce>(
-          to_stream(s), ArgReduce::ArgMax, sorted_axes[0]),
-      {a});
+  auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
+      compute_reduce_shape({axis}, a.shape());
+  auto out = (is_noop)
+      ? zeros(out_shape, uint32, s)
+      : array(
+            std::move(out_shape),
+            uint32,
+            std::make_shared<ArgReduce>(
+                to_stream(s), ArgReduce::ArgMax, sorted_axes[0]),
+            {a});
   if (!keepdims) {
-    out = squeeze(out, sorted_axes, s);
+    out = reshape(out, std::move(squeezed_shape), s);
   }
   return out;
 }
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index 8abe74471..2837d9a30 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -1763,7 +1763,7 @@ class TestOps(mlx_tests.MLXTestCase):
             mat_t = mat.astype(t)
             out = mx.cumsum(a_t, axis=-1)
             expected = (mat_t * a_t[:, None, :]).sum(axis=-1)
-            self.assertTrue(mx.allclose(out, expected, rtol=1e-2, atol=1e-3))
+            self.assertTrue(mx.allclose(out, expected, rtol=0.02, atol=1e-3))
         sizes = [1023, 1024, 1025, 2047, 2048, 2049]
         for s in sizes:
             a = mx.ones((s,), mx.int32)
diff --git a/python/tests/test_reduce.py b/python/tests/test_reduce.py
index 70c8a1172..0ef080f1a 100644
--- a/python/tests/test_reduce.py
+++ b/python/tests/test_reduce.py
@@ -43,6 +43,10 @@ class TestReduce(mlx_tests.MLXTestCase):
                     z_npy = np.sum(y_npy, axis=a) / 1000
                     z_mlx = mx.sum(y_mlx, axis=a) / 1000
                     mx.eval(z_mlx)
+                    if not np.allclose(z_npy, np.array(z_mlx), atol=1e-4):
+                        import pdb
+
+                        pdb.set_trace()
                     self.assertTrue(
                         np.allclose(z_npy, np.array(z_mlx), atol=1e-4)
                     )
@@ -57,6 +61,7 @@ class TestReduce(mlx_tests.MLXTestCase):
             "uint32",
             "int64",
             "uint64",
+            "complex64",
         ]
         float_dtypes = ["float32"]
 
@@ -114,6 +119,15 @@ class TestReduce(mlx_tests.MLXTestCase):
                 b = getattr(np, op)(data)
                 self.assertEqual(a.item(), b)
 
+    def test_edge_case(self):
+        x = (mx.random.normal((100, 1, 100, 100)) * 128).astype(mx.int32)
+        x = x.transpose(0, 3, 1, 2)
+
+        y = x.sum((0, 2, 3))
+        mx.eval(y)
+        z = np.array(x).sum((0, 2, 3))
+        self.assertTrue(np.all(z == y))
+
 
 if __name__ == "__main__":
     unittest.main(failfast=True)