From 0c5eea226ba8df2e2468a0b28d0456fec7b5d5fc Mon Sep 17 00:00:00 2001
From: Awni Hannun <awni@apple.com>
Date: Thu, 21 Nov 2024 19:53:00 -0800
Subject: [PATCH] Reduce specializations (#1607)

* start of reduce specializations

* fix all reduce

* fix many dims

* fix

* non-jit tests clear

* cleanup instantiations

* cpu merges

* change dim specializations

* optimize

* fix jit

* fix jit

* use higher precision for integer sum+prod

* fixes
---
 mlx/backend/common/reduce.cpp                | 214 ++++++----
 mlx/backend/metal/jit_kernels.cpp            |  46 +--
 mlx/backend/metal/kernels.h                  |   7 +-
 mlx/backend/metal/kernels/reduce.metal       | 258 +++++--------
 .../metal/kernels/reduction/reduce_all.h     |  15 +-
 .../metal/kernels/reduction/reduce_col.h     |  96 +++---
 .../metal/kernels/reduction/reduce_row.h     |  40 ++-
 mlx/backend/metal/kernels/utils.h            |  94 ++++--
 mlx/backend/metal/nojit_kernels.cpp          |   8 +-
 mlx/backend/metal/reduce.cpp                 | 311 ++++++++++++++----
 mlx/backend/metal/utils.cpp                  |   8 +-
 mlx/backend/metal/utils.h                    |   1 +
 mlx/ops.cpp                                  |  19 +-
 python/tests/test_reduce.py                  |  22 ++
 14 files changed, 733 insertions(+), 406 deletions(-)
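For context on the user-visible part of this change (the mlx/ops.cpp hunks
below): per the "use higher precision for integer sum+prod" bullet, sum and
prod over small integer inputs now accumulate at 32-bit precision and return
a 32-bit result. A quick sanity check, not part of the patch (the shapes and
values are illustrative):

    import mlx.core as mx

    # int8 inputs would overflow at their native width; after this change
    # the accumulator and the output dtype are promoted to int32.
    x = mx.full((1024,), 100, dtype=mx.int8)
    print(x.sum().dtype)   # mlx.core.int32
    print(x.sum().item())  # 102400, not a wrapped 8-bit value

    # Unsigned types promote to their 32-bit counterpart; bool still
    # promotes to int32 as before.
    u = mx.ones((10,), dtype=mx.uint16)
    print(u.sum().dtype)   # mlx.core.uint32
    b = mx.ones((10,), dtype=mx.bool_)
    print(b.sum().dtype)   # mlx.core.int32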
diff --git a/mlx/backend/common/reduce.cpp b/mlx/backend/common/reduce.cpp
index 87651d47e..049bb7409 100644
--- a/mlx/backend/common/reduce.cpp
+++ b/mlx/backend/common/reduce.cpp
@@ -120,48 +120,56 @@ struct MinReduce {
 };
 
 template <typename InT>
-void reduce_dispatch_out(
+void reduce_dispatch_and_or(
     const array& in,
     array& out,
     Reduce::ReduceType rtype,
     const std::vector<int>& axes) {
-  switch (rtype) {
-    case Reduce::And: {
-      reduction_op<InT, bool>(in, out, axes, true, AndReduce());
-      break;
+  if (rtype == Reduce::And) {
+    reduction_op<InT, bool>(in, out, axes, true, AndReduce());
+  } else {
+    reduction_op<InT, bool>(in, out, axes, false, OrReduce());
+  }
+}
+
+template <typename InT>
+void reduce_dispatch_sum_prod(
+    const array& in,
+    array& out,
+    Reduce::ReduceType rtype,
+    const std::vector<int>& axes) {
+  if (rtype == Reduce::Sum) {
+    auto op = [](auto y, auto x) { (*y) = (*y) + x; };
+    if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
+      reduction_op<InT, int32_t>(in, out, axes, 0, op);
+    } else {
+      reduction_op<InT, InT>(in, out, axes, 0, op);
     }
-    case Reduce::Or: {
-      reduction_op<InT, bool>(in, out, axes, false, OrReduce());
-      break;
-    }
-    case Reduce::Sum: {
-      auto op = [](auto y, auto x) { (*y) = (*y) + x; };
-      if (out.dtype() == int32) {
-        // special case since the input type can be bool
-        reduction_op<InT, int32_t>(in, out, axes, 0, op);
-      } else {
-        reduction_op<InT, InT>(in, out, axes, 0, op);
-      }
-      break;
-    }
-    case Reduce::Prod: {
-      auto op = [](auto y, auto x) { (*y) *= x; };
+  } else {
+    auto op = [](auto y, auto x) { (*y) *= x; };
+    if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
+      reduction_op<InT, int32_t>(in, out, axes, 1, op);
+    } else {
       reduction_op<InT, InT>(in, out, axes, 1, op);
-      break;
-    }
-    case Reduce::Max: {
-      auto init = Limits<InT>::min;
-      reduction_op<InT, InT>(in, out, axes, init, MaxReduce());
-      break;
-    }
-    case Reduce::Min: {
-      auto init = Limits<InT>::max;
-      reduction_op<InT, InT>(in, out, axes, init, MinReduce());
-      break;
     }
   }
 }
 
+template <typename InT>
+void reduce_dispatch_min_max(
+    const array& in,
+    array& out,
+    Reduce::ReduceType rtype,
+    const std::vector<int>& axes) {
+  if (rtype == Reduce::Max) {
+    auto init = Limits<InT>::min;
+    reduction_op<InT, InT>(in, out, axes, init, MaxReduce());
+  } else {
+    auto init = Limits<InT>::max;
+    reduction_op<InT, InT>(in, out, axes, init, MinReduce());
+  }
+}
+
 } // namespace
 
 void nd_loop(
@@ -190,46 +198,114 @@
 void Reduce::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
-  switch (in.dtype()) {
-    case bool_:
-      reduce_dispatch_out<bool>(in, out, reduce_type_, axes_);
+  switch (reduce_type_) {
+    case Reduce::And:
+    case Reduce::Or: {
+      switch (in.dtype()) {
+        case bool_:
+        case uint8:
+        case int8:
+          reduce_dispatch_and_or<int8_t>(in, out, reduce_type_, axes_);
+          break;
+        case int16:
+        case uint16:
+        case float16:
+        case bfloat16:
+          reduce_dispatch_and_or<int16_t>(in, out, reduce_type_, axes_);
+          break;
+        case uint32:
+        case int32:
+        case float32:
+          reduce_dispatch_and_or<int32_t>(in, out, reduce_type_, axes_);
+          break;
+        case uint64:
+        case int64:
+        case complex64:
+          reduce_dispatch_and_or<int64_t>(in, out, reduce_type_, axes_);
+          break;
+      }
       break;
-    case uint8:
-      reduce_dispatch_out<uint8_t>(in, out, reduce_type_, axes_);
+    }
+    case Reduce::Sum:
+    case Reduce::Prod: {
+      switch (in.dtype()) {
+        case bool_:
+        case uint8:
+        case int8:
+          reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
+          break;
+        case int16:
+        case uint16:
+          reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
+          break;
+        case int32:
+        case uint32:
+          reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
+          break;
+        case int64:
+        case uint64:
+          reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
+          break;
+        case float16:
+          reduce_dispatch_sum_prod<float16_t>(in, out, reduce_type_, axes_);
+          break;
+        case bfloat16:
+          reduce_dispatch_sum_prod<bfloat16_t>(in, out, reduce_type_, axes_);
+          break;
+        case float32:
+          reduce_dispatch_sum_prod<float>(in, out, reduce_type_, axes_);
+          break;
+        case complex64:
+          reduce_dispatch_sum_prod<complex64_t>(in, out, reduce_type_, axes_);
+          break;
+      }
       break;
-    case uint16:
-      reduce_dispatch_out<uint16_t>(in, out, reduce_type_, axes_);
-      break;
-    case uint32:
-      reduce_dispatch_out<uint32_t>(in, out, reduce_type_, axes_);
-      break;
-    case uint64:
-      reduce_dispatch_out<uint64_t>(in, out, reduce_type_, axes_);
-      break;
-    case int8:
-      reduce_dispatch_out<int8_t>(in, out, reduce_type_, axes_);
-      break;
-    case int16:
-      reduce_dispatch_out<int16_t>(in, out, reduce_type_, axes_);
-      break;
-    case int32:
-      reduce_dispatch_out<int32_t>(in, out, reduce_type_, axes_);
-      break;
-    case int64:
-      reduce_dispatch_out<int64_t>(in, out, reduce_type_, axes_);
-      break;
-    case float16:
-      reduce_dispatch_out<float16_t>(in, out, reduce_type_, axes_);
-      break;
-    case float32:
-      reduce_dispatch_out<float>(in, out, reduce_type_, axes_);
-      break;
-    case bfloat16:
-      reduce_dispatch_out<bfloat16_t>(in, out, reduce_type_, axes_);
-      break;
-    case complex64:
-      reduce_dispatch_out<complex64_t>(in, out, reduce_type_, axes_);
+    }
+    case Reduce::Max:
+    case Reduce::Min: {
+      switch (in.dtype()) {
+        case bool_:
+          reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_);
+          break;
+        case uint8:
+          reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
+          break;
+        case uint16:
+          reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
+          break;
+        case uint32:
+          reduce_dispatch_min_max<uint32_t>(in, out, reduce_type_, axes_);
+          break;
+        case uint64:
+          reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);
+          break;
+        case int8:
+          reduce_dispatch_min_max<int8_t>(in, out, reduce_type_, axes_);
+          break;
+        case int16:
+          reduce_dispatch_min_max<int16_t>(in, out, reduce_type_, axes_);
+          break;
+        case int32:
+          reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);
+          break;
+        case int64:
+          reduce_dispatch_min_max<int64_t>(in, out, reduce_type_, axes_);
+          break;
+        case float16:
+          reduce_dispatch_min_max<float16_t>(in, out, reduce_type_, axes_);
+          break;
+        case float32:
+          reduce_dispatch_min_max<float>(in, out, reduce_type_, axes_);
+          break;
+        case bfloat16:
+          reduce_dispatch_min_max<bfloat16_t>(in, out, reduce_type_, axes_);
+          break;
+        case complex64:
+          reduce_dispatch_min_max<complex64_t>(in, out, reduce_type_, axes_);
+          break;
+      }
       break;
+    }
   }
 }
diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp
index aa43cb0f6..b09f920dd 100644
--- a/mlx/backend/metal/jit_kernels.cpp
+++ 
b/mlx/backend/metal/jit_kernels.cpp @@ -1,5 +1,4 @@ // Copyright © 2024 Apple Inc. - #include "mlx/backend/common/compiled.h" #include "mlx/backend/metal/jit/arange.h" #include "mlx/backend/metal/jit/gemv_masked.h" @@ -338,17 +337,17 @@ MTL::ComputePipelineState* get_reduce_init_kernel( const std::string& kernel_name, const std::string& func_name, const std::string& op_name, - const array& out) { + const Dtype& out_type) { auto lib = d.get_library(kernel_name, [&]() { - std::ostringstream kernel_source; std::string op_type = op_name; op_type[0] = std::toupper(op_name[0]); - auto out_type = get_type_string(out.dtype()); - std::string op = op_type + "<" + out_type + ">"; - kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce(); - kernel_source << get_template_definition( - kernel_name, func_name, out_type, op); - return kernel_source.str(); + auto out_t = get_type_string(out_type); + std::string op = op_type + "<" + out_t + ">"; + std::string kernel_source = metal::utils(); + kernel_source += metal::reduce_utils(); + kernel_source += metal::reduce(); + kernel_source += get_template_definition(kernel_name, func_name, out_t, op); + return kernel_source; }); return d.get_kernel(kernel_name, lib); } @@ -358,30 +357,31 @@ MTL::ComputePipelineState* get_reduce_kernel( const std::string& kernel_name, const std::string& func_name, const std::string& op_name, - const array& in, - const array& out, + const Dtype& in_type, + const Dtype& out_type, + const std::string& idx_t, int ndim /* = -1 */, int bm /* = -1 */, int bn /* = -1 */) { auto lib = d.get_library(kernel_name, [&]() { std::string op_type = op_name; op_type[0] = std::toupper(op_name[0]); - std::ostringstream kernel_source; - auto in_type = get_type_string(in.dtype()); - auto out_type = get_type_string(out.dtype()); - std::string op = op_type + "<" + out_type + ">"; - kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce(); + auto in_t = get_type_string(in_type); + auto out_t = get_type_string(out_type); + std::string op = op_type + "<" + out_t + ">"; + std::string kernel_source = metal::utils(); + concatenate(kernel_source, metal::reduce_utils(), metal::reduce()); if (bm >= 0) { - kernel_source << get_template_definition( - kernel_name, func_name, in_type, out_type, op, ndim, bm, bn); + kernel_source += get_template_definition( + kernel_name, func_name, in_t, out_t, op, idx_t, ndim, bm, bn); } else if (ndim >= 0) { - kernel_source << get_template_definition( - kernel_name, func_name, in_type, out_type, op, ndim); + kernel_source += get_template_definition( + kernel_name, func_name, in_t, out_t, op, idx_t, ndim); } else { - kernel_source << get_template_definition( - kernel_name, func_name, in_type, out_type, op); + kernel_source += get_template_definition( + kernel_name, func_name, in_t, out_t, op, idx_t); } - return kernel_source.str(); + return kernel_source; }); auto st = d.get_kernel(kernel_name, lib); return st; diff --git a/mlx/backend/metal/kernels.h b/mlx/backend/metal/kernels.h index b5f9b0a92..20e3bd907 100644 --- a/mlx/backend/metal/kernels.h +++ b/mlx/backend/metal/kernels.h @@ -81,15 +81,16 @@ MTL::ComputePipelineState* get_reduce_init_kernel( const std::string& kernel_name, const std::string& func_name, const std::string& op_name, - const array& out); + const Dtype& out_type); MTL::ComputePipelineState* get_reduce_kernel( metal::Device& d, const std::string& kernel_name, const std::string& func_name, const std::string& op_name, - const array& in, - const array& out, + const Dtype& 
in_type, + const Dtype& out_type, + const std::string& idx_t, int ndim = -1, int bm = -1, int bn = -1); diff --git a/mlx/backend/metal/kernels/reduce.metal b/mlx/backend/metal/kernels/reduce.metal index d68045047..4af18c970 100644 --- a/mlx/backend/metal/kernels/reduce.metal +++ b/mlx/backend/metal/kernels/reduce.metal @@ -10,186 +10,156 @@ #include "mlx/backend/metal/kernels/reduction/ops.h" #include "mlx/backend/metal/kernels/reduce.h" -#define instantiate_reduce_helper_floats(inst_f, name, op) \ - inst_f(name, float16, half, op) \ - inst_f(name, float32, float, op) \ - inst_f(name, bfloat16, bfloat16_t, op) +#define instantiate_init_reduce(name, tname, type, op) \ + instantiate_kernel("init_reduce_" #name #tname, init_reduce, type, op) -#define instantiate_reduce_helper_uints(inst_f, name, op) \ - inst_f(name, uint8, uint8_t, op) \ - inst_f(name, uint16, uint16_t, op) \ - inst_f(name, uint32, uint32_t, op) +instantiate_init_reduce(and, bool_, bool, And) +instantiate_init_reduce(or, bool_, bool, Or) -#define instantiate_reduce_helper_ints(inst_f, name, op) \ - inst_f(name, int8, int8_t, op) \ - inst_f(name, int16, int16_t, op) \ - inst_f(name, int32, int32_t, op) +#define instantiate_init_sum_prod(name, op) \ + instantiate_init_reduce(name, int32, int32_t, op) \ + instantiate_init_reduce(name, int64, int64_t, op) \ + instantiate_init_reduce(name, float16, float16_t, op) \ + instantiate_init_reduce(name, bfloat16, bfloat16_t, op) \ + instantiate_init_reduce(name, float32, float, op) \ + instantiate_init_reduce(name, complex64, complex64_t, op) -#define instantiate_reduce_helper_64b(inst_f, name, op) \ - inst_f(name, int64, int64_t, op) \ - inst_f(name, uint64, uint64_t, op) \ - inst_f(name, complex64, complex64_t, op) +instantiate_init_sum_prod(sum, Sum) +instantiate_init_sum_prod(prod, Prod) -#define instantiate_reduce_helper_types(inst_f, name, op) \ - instantiate_reduce_helper_floats(inst_f, name, op) \ - instantiate_reduce_helper_uints(inst_f, name, op) \ - instantiate_reduce_helper_ints(inst_f, name, op) +#define instantiate_init_min_max(name, op) \ + instantiate_init_reduce(name, bool_, bool, op) \ + instantiate_init_reduce(name, int8, int8_t, op) \ + instantiate_init_reduce(name, int16, int16_t, op) \ + instantiate_init_reduce(name, int32, int32_t, op) \ + instantiate_init_reduce(name, int64, int64_t, op) \ + instantiate_init_reduce(name, uint8, uint8_t, op) \ + instantiate_init_reduce(name, uint16, uint16_t, op) \ + instantiate_init_reduce(name, uint32, uint32_t, op) \ + instantiate_init_reduce(name, uint64, uint64_t, op) \ + instantiate_init_reduce(name, float16, float16_t, op) \ + instantiate_init_reduce(name, bfloat16, bfloat16_t, op) \ + instantiate_init_reduce(name, float32, float, op) \ + instantiate_init_reduce(name, complex64, complex64_t, op) -#define instantiate_reduce_ops(inst_f, type_f) \ - type_f(inst_f, sum, Sum) \ - type_f(inst_f, prod, Prod) \ - type_f(inst_f, min, Min) \ - type_f(inst_f, max, Max) - -// Special case for bool reductions -#define instantiate_reduce_from_types_helper( \ - inst_f, name, tname, itype, otype, op) \ - inst_f(name##tname, itype, otype, op) - -#define instantiate_reduce_from_types(inst_f, name, otype, op) \ - instantiate_reduce_from_types_helper( \ - inst_f, name, bool_, bool, otype, op) \ - instantiate_reduce_from_types_helper( \ - inst_f, name, uint8, uint8_t, otype, op) \ - instantiate_reduce_from_types_helper( \ - inst_f, name, uint16, uint16_t, otype, op) \ - instantiate_reduce_from_types_helper( \ - inst_f, name, uint32, uint32_t, 
otype, op) \ - instantiate_reduce_from_types_helper( \ - inst_f, name, uint64, uint64_t, otype, op) \ - instantiate_reduce_from_types_helper( \ - inst_f, name, int8, int8_t, otype, op) \ - instantiate_reduce_from_types_helper( \ - inst_f, name, int16, int16_t, otype, op) \ - instantiate_reduce_from_types_helper( \ - inst_f, name, int32, int32_t, otype, op) \ - instantiate_reduce_from_types_helper( \ - inst_f, name, int64, int64_t, otype, op) \ - instantiate_reduce_from_types_helper( \ - inst_f, name, float16, half, otype, op) \ - instantiate_reduce_from_types_helper( \ - inst_f, \ - name, \ - float32, \ - float, \ - otype, \ - op) \ - instantiate_reduce_from_types_helper( \ - inst_f, \ - name, \ - bfloat16, \ - bfloat16_t, \ - otype, \ - op) - -#define instantiate_init_reduce(name, otype, op) \ - instantiate_kernel("init_reduce_" #name, \ - init_reduce, \ - otype, op) - -#define instantiate_init_reduce_helper(name, tname, type, op) \ - instantiate_init_reduce(name##tname, type, op) - -instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_types) -instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_64b) - -instantiate_init_reduce(andbool_, bool, And) -instantiate_init_reduce(orbool_, bool, Or) +instantiate_init_min_max(min, Min) +instantiate_init_min_max(max, Max) #define instantiate_all_reduce(name, itype, otype, op) \ instantiate_kernel("all_reduce_" #name, \ all_reduce, \ itype, otype, op) -#define instantiate_same_all_reduce_helper(name, tname, type, op) \ - instantiate_all_reduce(name##tname, type, type, op) +#define instantiate_col_reduce_small(name, itype, otype, op, dim) \ + instantiate_kernel("col_reduce_small_" #dim "_reduce_" #name, \ + col_reduce_small, \ + itype, otype, op, uint, dim) \ + instantiate_kernel("col_reduce_longcolumn_" #dim "_reduce_" #name, \ + col_reduce_longcolumn, \ + itype, otype, op, uint, dim) \ + instantiate_kernel("col_reduce_small_large_" #dim "_reduce_" #name, \ + col_reduce_small, \ + itype, otype, op, size_t, dim) \ + instantiate_kernel("col_reduce_longcolumn_large_" #dim "_reduce_" #name, \ + col_reduce_longcolumn, \ + itype, otype, op, size_t, dim) -instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_types) -instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_64b) +#define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \ + instantiate_kernel("col_reduce_looped_" #dim "_" #bm "_" #bn "_reduce_" #name, \ + col_reduce_looped, \ + itype, otype, op, uint, dim, bm, bn) \ + instantiate_kernel("col_reduce_looped_large_" #dim "_" #bm "_" #bn "_reduce_" #name, \ + col_reduce_looped, \ + itype, otype, op, size_t, dim, bm, bn) -instantiate_reduce_from_types(instantiate_all_reduce, and, bool, And) -instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or) - -// special case bool with larger output type -instantiate_all_reduce(sumbool_, bool, uint32_t, Sum) - -#define instantiate_col_reduce_small(name, itype, otype, op, dim) \ - instantiate_kernel("col_reduce_small_" #dim "_reduce_" #name, \ - col_reduce_small, \ - itype, otype, op, dim) \ - instantiate_kernel("col_reduce_longcolumn_" #dim "_reduce_" #name, \ - col_reduce_longcolumn, \ - itype, otype, op, dim) - -#define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \ - instantiate_kernel("col_reduce_looped_" #dim "_" #bm "_" #bn "_reduce_" #name, \ - col_reduce_looped, \ - itype, otype, op, dim, bm, bn) - -#define 
instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, bm, bn) \ - instantiate_kernel("col_reduce_2pass_" #dim "_" #bm "_" #bn "_reduce_" #name, \ - col_reduce_2pass, \ - itype, otype, op, dim, bm, bn) +#define instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, bm, bn) \ + instantiate_kernel("col_reduce_2pass_" #dim "_" #bm "_" #bn "_reduce_" #name, \ + col_reduce_2pass, \ + itype, otype, op, uint, dim, bm, bn) \ + instantiate_kernel("col_reduce_2pass_large_" #dim "_" #bm "_" #bn "_reduce_" #name, \ + col_reduce_2pass, \ + itype, otype, op, size_t, dim, bm, bn) #define instantiate_col_reduce_looped(name, itype, otype, op, dim) \ instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32) \ instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, 32, 32) #define instantiate_col_reduce_general(name, itype, otype, op) \ - instantiate_col_reduce_small(name, itype, otype, op, 0) \ instantiate_col_reduce_small(name, itype, otype, op, 1) \ instantiate_col_reduce_small(name, itype, otype, op, 2) \ - instantiate_col_reduce_small(name, itype, otype, op, 3) \ - instantiate_col_reduce_small(name, itype, otype, op, 4) \ - instantiate_col_reduce_looped(name, itype, otype, op, 0) \ + instantiate_col_reduce_small(name, itype, otype, op, 5) \ instantiate_col_reduce_looped(name, itype, otype, op, 1) \ instantiate_col_reduce_looped(name, itype, otype, op, 2) \ - instantiate_col_reduce_looped(name, itype, otype, op, 3) \ - instantiate_col_reduce_looped(name, itype, otype, op, 4) + instantiate_col_reduce_looped(name, itype, otype, op, 5) -#define instantiate_same_col_reduce_helper(name, tname, type, op) \ - instantiate_col_reduce_general(name##tname, type, type, op) +#define instantiate_row_reduce_small(name, itype, otype, op, dim) \ + instantiate_kernel("row_reduce_small_" #dim "_reduce_" #name, \ + row_reduce_small, \ + itype, otype, op, uint, dim) \ + instantiate_kernel("row_reduce_small_large_" #dim "_reduce_" #name, \ + row_reduce_small, \ + itype, otype, op, size_t, dim) -instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types) -instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_64b) - -instantiate_col_reduce_general(sumbool_, bool, uint32_t, Sum) -instantiate_reduce_from_types(instantiate_col_reduce_general, and, bool, And) -instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or) - -#define instantiate_row_reduce_small(name, itype, otype, op, dim) \ - instantiate_kernel("row_reduce_small_" #dim "_reduce_" #name, \ - row_reduce_small, \ - itype, otype, op, dim) - -#define instantiate_row_reduce_looped(name, itype, otype, op, dim) \ - instantiate_kernel("row_reduce_looped_" #dim "_reduce_" #name, \ - row_reduce_looped, \ - itype, otype, op, dim) +#define instantiate_row_reduce_looped(name, itype, otype, op, dim) \ + instantiate_kernel("row_reduce_looped_" #dim "_reduce_" #name, \ + row_reduce_looped, \ + itype, otype, op, uint, dim) \ + instantiate_kernel("row_reduce_looped_large_" #dim "_reduce_" #name, \ + row_reduce_looped, \ + itype, otype, op, size_t, dim) #define instantiate_row_reduce_general(name, itype, otype, op) \ - instantiate_row_reduce_small(name, itype, otype, op, 0) \ instantiate_row_reduce_small(name, itype, otype, op, 1) \ instantiate_row_reduce_small(name, itype, otype, op, 2) \ - instantiate_row_reduce_small(name, itype, otype, op, 3) \ - instantiate_row_reduce_small(name, itype, otype, op, 4) \ - instantiate_row_reduce_looped(name, itype, otype, op, 0) \ + 
instantiate_row_reduce_small(name, itype, otype, op, 5) \ instantiate_row_reduce_looped(name, itype, otype, op, 1) \ instantiate_row_reduce_looped(name, itype, otype, op, 2) \ - instantiate_row_reduce_looped(name, itype, otype, op, 3) \ - instantiate_row_reduce_looped(name, itype, otype, op, 4) \ + instantiate_row_reduce_looped(name, itype, otype, op, 5) \ instantiate_kernel("row_reduce_simple_" #name, \ row_reduce_simple, \ itype, otype, op) -#define instantiate_same_row_reduce_helper(name, tname, type, op) \ - instantiate_row_reduce_general(name##tname, type, type, op) +#define instantiate_reduce_functions(name, tname, itype, otype, op) \ + instantiate_all_reduce(name##tname, itype, otype, op) \ + instantiate_row_reduce_general(name##tname, itype, otype, op) \ + instantiate_col_reduce_general(name##tname, itype, otype, op) -instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_types) -instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_64b) +#define instantiate_and_or(name, op) \ + instantiate_reduce_functions(name, bool_, bool, bool, op) \ + instantiate_reduce_functions(name, int16, int16_t, bool, op) \ + instantiate_reduce_functions(name, int32, int32_t, bool, op) \ + instantiate_reduce_functions(name, int64, int64_t, bool, op) -instantiate_reduce_from_types(instantiate_row_reduce_general, and, bool, And) -instantiate_reduce_from_types(instantiate_row_reduce_general, or, bool, Or) +instantiate_and_or(and, And) +instantiate_and_or(or, Or) -instantiate_row_reduce_general(sumbool_, bool, uint32_t, Sum) +#define instantiate_sum_prod(name, op) \ + instantiate_reduce_functions(name, int8, int8_t, int32_t, op) \ + instantiate_reduce_functions(name, int16, int16_t, int32_t, op) \ + instantiate_reduce_functions(name, int32, int32_t, int32_t, op) \ + instantiate_reduce_functions(name, int64, int64_t, int64_t, op) \ + instantiate_reduce_functions(name, float16, float16_t, float16_t, op) \ + instantiate_reduce_functions(name, bfloat16, bfloat16_t, bfloat16_t, op) \ + instantiate_reduce_functions(name, float32, float, float, op) \ + instantiate_reduce_functions(name, complex64, complex64_t, complex64_t, op) + +instantiate_sum_prod(sum, Sum) +instantiate_sum_prod(prod, Prod) + +#define instantiate_min_max(name, op) \ + instantiate_reduce_functions(name, int8, int8_t, int8_t, op) \ + instantiate_reduce_functions(name, int16, int16_t, int16_t, op) \ + instantiate_reduce_functions(name, int32, int32_t, int32_t, op) \ + instantiate_reduce_functions(name, int64, int64_t, int64_t, op) \ + instantiate_reduce_functions(name, uint8, uint8_t, uint8_t, op) \ + instantiate_reduce_functions(name, uint16, uint16_t, uint16_t, op) \ + instantiate_reduce_functions(name, uint32, uint32_t, uint32_t, op) \ + instantiate_reduce_functions(name, uint64, uint64_t, uint64_t, op) \ + instantiate_reduce_functions(name, float16, float16_t, float16_t, op) \ + instantiate_reduce_functions(name, bfloat16, bfloat16_t, bfloat16_t, op) \ + instantiate_reduce_functions(name, float32, float, float, op) \ + instantiate_reduce_functions(name, complex64, complex64_t, complex64_t, op) + +instantiate_min_max(min, Min) +instantiate_min_max(max, Max) // clang-format on diff --git a/mlx/backend/metal/kernels/reduction/reduce_all.h b/mlx/backend/metal/kernels/reduction/reduce_all.h index 381d5e20b..e0d08392c 100644 --- a/mlx/backend/metal/kernels/reduction/reduce_all.h +++ b/mlx/backend/metal/kernels/reduction/reduce_all.h @@ -1,6 +1,11 @@ // Copyright © 2023-2024 Apple Inc. 
-template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
+template <
+    typename T,
+    typename U,
+    typename Op,
+    typename IdxT = int64_t,
+    int N_READS = REDUCE_N_READS>
 [[kernel]] void all_reduce(
     const device T* in [[buffer(0)]],
     device U* out [[buffer(1)]],
@@ -16,10 +21,10 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
   threadgroup U shared_vals[simd_size];
 
   U total = Op::init;
-  int64_t start_idx = gid.y * row_size;
-  int64_t actual_row =
+  IdxT start_idx = gid.y * IdxT(row_size);
+  IdxT actual_row =
       (start_idx + row_size <= in_size) ? row_size : in_size - start_idx;
-  int64_t blocks = actual_row / (lsize.x * N_READS);
+  IdxT blocks = actual_row / (lsize.x * N_READS);
   int extra = actual_row - blocks * (lsize.x * N_READS);
   extra -= lid.x * N_READS;
   start_idx += lid.x * N_READS;
@@ -30,7 +35,7 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
     extra = 0;
   }
 
-  for (int64_t b = 0; b < blocks; b++) {
+  for (IdxT b = 0; b < blocks; b++) {
    for (int i = 0; i < N_READS; i++) {
      total = op(static_cast<U>(in[i]), total);
    }
diff --git a/mlx/backend/metal/kernels/reduction/reduce_col.h b/mlx/backend/metal/kernels/reduction/reduce_col.h
index 735e80afe..2fa5132d9 100644
--- a/mlx/backend/metal/kernels/reduction/reduce_col.h
+++ b/mlx/backend/metal/kernels/reduction/reduce_col.h
@@ -1,6 +1,6 @@
 // Copyright © 2023-2024 Apple Inc.
 
-template <typename T, typename U, typename Op, int NDIMS>
+template <typename T, typename U, typename Op, typename IdxT, int NDIMS>
 [[kernel]] void col_reduce_small(
     const device T* in [[buffer(0)]],
     device U* out [[buffer(1)]],
@@ -19,7 +19,7 @@ template <typename T, typename U, typename Op, int NDIMS>
     uint3 lsize [[threads_per_threadgroup]]) {
   constexpr int n_reads = 4;
   Op op;
-  looped_elem_to_loc<NDIMS> loop;
+  LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
   const device T* row;
 
   U totals[n_reads];
@@ -27,20 +27,20 @@ template <typename T, typename U, typename Op, int NDIMS>
     totals[i] = Op::init;
   }
 
-  size_t column = size_t(gid.x) * lsize.x * n_reads + lid.x * n_reads;
+  IdxT column = IdxT(gid.x) * lsize.x * n_reads + lid.x * n_reads;
   if (column >= reduction_stride) {
     return;
   }
   bool safe = column + n_reads <= reduction_stride;
 
-  size_t out_idx = gid.y + gsize.y * size_t(gid.z);
-  size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
+  IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
+  IdxT in_idx = elem_to_loc<IdxT>(out_idx, shape, strides, ndim);
   in += in_idx + column;
 
-  size_t total_rows = non_col_reductions * reduction_size;
+  IdxT total_rows = IdxT(non_col_reductions) * IdxT(reduction_size);
   loop.next(lid.y, reduce_shape, reduce_strides);
-  for (size_t r = lid.y; r < total_rows; r += lsize.y) {
-    row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
+  for (IdxT r = lid.y; r < total_rows; r += lsize.y) {
+    row = in + loop.location();
     if (safe) {
       for (int i = 0; i < n_reads; i++) {
         totals[i] = op(static_cast<U>(row[i]), totals[i]);
@@ -80,7 +80,7 @@ template <typename T, typename U, typename Op, int NDIMS>
   }
 
   if (lid.y == 0) {
-    out += out_idx * reduction_stride + column;
+    out += out_idx * IdxT(reduction_stride) + column;
     if (safe) {
       for (int i = 0; i < n_reads; i++) {
         out[i] = totals[i];
@@ -93,7 +93,7 @@ template <typename T, typename U, typename Op, int NDIMS>
   }
 }
 
-template <typename T, typename U, typename Op, int NDIMS>
+template <typename T, typename U, typename Op, typename IdxT, int NDIMS>
 [[kernel]] void col_reduce_longcolumn(
     const device T* in [[buffer(0)]],
     device U* out [[buffer(1)]],
@@ -112,19 +112,19 @@ template <typename T, typename U, typename Op, int NDIMS>
     uint3 lid [[thread_position_in_threadgroup]],
     uint3 lsize [[threads_per_threadgroup]]) {
   Op op;
-  looped_elem_to_loc<NDIMS> loop;
+  LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
   const device T* row;
 
-  size_t out_idx = gid.x + gsize.x * size_t(gid.y);
-  size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
+  IdxT out_idx = gid.x + gsize.x * IdxT(gid.y);
+  IdxT in_idx = elem_to_loc<IdxT>(out_idx, shape, strides, ndim);
   in += in_idx + lid.x;
 
   U total = Op::init;
-  size_t total_rows = non_col_reductions * reduction_size;
+  IdxT total_rows = IdxT(non_col_reductions) * IdxT(reduction_size);
   loop.next(gid.z * lsize.y + lid.y, reduce_shape, reduce_strides);
-  for (size_t r = gid.z * lsize.y + lid.y; r < total_rows;
+  for (IdxT r = gid.z * lsize.y + lid.y; r < total_rows;
        r += lsize.y * gsize.z) {
-    row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
+    row = in + loop.location();
     total = op(static_cast<U>(*row), total);
     loop.next(lsize.y * gsize.z, reduce_shape, reduce_strides);
   }
@@ -136,7 +136,8 @@ template <typename T, typename U, typename Op, int NDIMS>
     for (uint i = 1; i < lsize.y; i++) {
       total = op(total, shared_vals[i * lsize.x + lid.x]);
     }
-    out[gid.z * out_size + out_idx * reduction_stride + lid.x] = total;
+    out[gid.z * IdxT(out_size) + out_idx * IdxT(reduction_stride) + lid.x] =
+        total;
   }
 }
 
@@ -151,7 +152,14 @@ template <typename T, typename U, typename Op, int NDIMS>
  * totals with a loop.
  * 7. Write them to the output
  */
-template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
+template <
+    typename T,
+    typename U,
+    typename Op,
+    typename IdxT,
+    int NDIMS,
+    int BM,
+    int BN>
 [[kernel]] void col_reduce_looped(
     const device T* in [[buffer(0)]],
     device U* out [[buffer(1)]],
@@ -176,7 +184,7 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
   threadgroup U shared_vals[BN * BM];
   U totals[n_reads];
-  looped_elem_to_loc<NDIMS> loop;
+  LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
   const device T* row;
 
   for (int i = 0; i < n_reads; i++) {
@@ -185,17 +193,17 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
 
   short lid = simd_group_id * simd_size + simd_lane_id;
   short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);
-  size_t column = BN * gid.x + offset.x;
+  IdxT column = BN * gid.x + offset.x;
   bool safe = column + n_reads <= reduction_stride;
 
-  size_t out_idx = gid.y + gsize.y * size_t(gid.z);
-  size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
+  IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
+  IdxT in_idx = elem_to_loc<IdxT>(out_idx, shape, strides, ndim);
   in += in_idx + column;
 
-  size_t total = non_col_reductions * reduction_size;
+  IdxT total = IdxT(non_col_reductions) * IdxT(reduction_size);
   loop.next(offset.y, reduce_shape, reduce_strides);
-  for (size_t r = offset.y; r < total; r += BM) {
-    row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
+  for (IdxT r = offset.y; r < total; r += BM) {
+    row = in + loop.location();
 
     if (safe) {
       for (int i = 0; i < n_reads; i++) {
@@ -235,8 +243,8 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
 
     // Write the output.
     if (simd_lane_id == 0) {
-      size_t out_column = BN * gid.x + out_offset.x;
-      out += out_idx * reduction_stride + out_column;
+      IdxT out_column = BN * gid.x + out_offset.x;
+      out += out_idx * IdxT(reduction_stride) + out_column;
       if (out_column + n_outputs <= reduction_stride) {
         for (int i = 0; i < n_outputs; i++) {
           out[i] = totals[i];
@@ -269,7 +277,7 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
 
     // Write the output.
     if (offset.y == 0) {
-      out += out_idx * reduction_stride + column;
+      out += out_idx * IdxT(reduction_stride) + column;
       if (safe) {
         for (int i = 0; i < n_reads; i++) {
           out[i] = totals[i];
@@ -283,7 +291,14 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
   }
 }
 
-template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
+template <
+    typename T,
+    typename U,
+    typename Op,
+    typename IdxT,
+    int NDIMS,
+    int BM,
+    int BN>
 [[kernel]] void col_reduce_2pass(
     const device T* in [[buffer(0)]],
     device U* out [[buffer(1)]],
@@ -312,7 +327,7 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
   threadgroup U shared_vals[BN * BM];
   U totals[n_reads];
-  looped_elem_to_loc<NDIMS> loop;
+  LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
   const device T* row;
 
   for (int i = 0; i < n_reads; i++) {
@@ -321,20 +336,19 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
 
   short lid = simd_group_id * simd_size + simd_lane_id;
   short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);
-  size_t column = BN * gid.x + offset.x;
+  IdxT column = BN * gid.x + offset.x;
   bool safe = column + n_reads <= reduction_stride;
 
-  size_t full_idx = gid.y + gsize.y * size_t(gid.z);
-  size_t block_idx = full_idx / out_size;
-  size_t out_idx = full_idx % out_size;
-  size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
+  IdxT full_idx = gid.y + gsize.y * IdxT(gid.z);
+  IdxT block_idx = full_idx / IdxT(out_size);
+  IdxT out_idx = full_idx % IdxT(out_size);
+  IdxT in_idx = elem_to_loc<IdxT>(out_idx, shape, strides, ndim);
   in += in_idx + column;
 
-  size_t total = non_col_reductions * reduction_size;
+  IdxT total = IdxT(non_col_reductions) * IdxT(reduction_size);
   loop.next(offset.y + block_idx * BM, reduce_shape, reduce_strides);
-  for (size_t r = offset.y + block_idx * BM; r < total;
-       r += outer_blocks * BM) {
-    row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
+  for (IdxT r = offset.y + block_idx * BM; r < total; r += outer_blocks * BM) {
+    row = in + loop.location();
 
     if (safe) {
       for (int i = 0; i < n_reads; i++) {
@@ -369,8 +383,8 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
 
     // Write the output.
     if (simd_lane_id == 0) {
-      size_t out_column = BN * gid.x + out_offset.x;
-      out += full_idx * reduction_stride + out_column;
+      IdxT out_column = BN * gid.x + out_offset.x;
+      out += full_idx * IdxT(reduction_stride) + out_column;
       if (out_column + n_outputs <= reduction_stride) {
         for (int i = 0; i < n_outputs; i++) {
           out[i] = totals[i];
diff --git a/mlx/backend/metal/kernels/reduction/reduce_row.h b/mlx/backend/metal/kernels/reduction/reduce_row.h
index af8a01da7..746361255 100644
--- a/mlx/backend/metal/kernels/reduction/reduce_row.h
+++ b/mlx/backend/metal/kernels/reduction/reduce_row.h
@@ -193,6 +193,7 @@ template <
     typename T,
     typename U,
     typename Op,
+    typename IdxT,
     int NDIMS,
     int N_READS = REDUCE_N_READS>
 [[kernel]] void row_reduce_small(
@@ -214,20 +215,20 @@ template <
   Op op;
 
   U total_val = Op::init;
-  looped_elem_to_loc<NDIMS> loop;
+  LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
 
   // Precompute some row reduction numbers
   const device T* row;
-  int blocks = row_size / N_READS;
-  int extra = row_size % N_READS;
+  int blocks = IdxT(row_size) / N_READS;
+  int extra = IdxT(row_size) % N_READS;
 
   if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) {
     // Simple loop over non_row_reductions and reduce the row in the thread.
-    size_t out_idx = tid.x + tsize.y * size_t(tid.y);
-    in += elem_to_loc(out_idx, shape, strides, ndim);
+    IdxT out_idx = tid.x + tsize.y * IdxT(tid.y);
+    in += elem_to_loc<IdxT>(out_idx, shape, strides, ndim);
 
     for (uint r = 0; r < non_row_reductions; r++) {
-      row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
+      row = in + loop.location();
       thread_reduce(total_val, row, blocks, extra);
       loop.next(reduce_shape, reduce_strides);
     }
@@ -236,13 +237,13 @@ template <
   } else {
     // Collaboratively reduce over non_row_reductions in the simdgroup. Each
     // thread reduces every 32nd row and then a simple simd reduce.
-    size_t out_idx = gid.y + gsize.y * size_t(gid.z);
-    in += elem_to_loc(out_idx, shape, strides, ndim);
+    IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
+    in += elem_to_loc<IdxT>(out_idx, shape, strides, ndim);
 
     loop.next(simd_lane_id, reduce_shape, reduce_strides);
 
     for (uint r = simd_lane_id; r < non_row_reductions; r += simd_size) {
-      row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
+      row = in + loop.location();
       thread_reduce(total_val, row, blocks, extra);
       loop.next(simd_size, reduce_shape, reduce_strides);
     }
@@ -259,6 +260,7 @@ template <
     typename T,
     typename U,
     typename Op,
+    typename IdxT = size_t,
     int N_READS = REDUCE_N_READS,
     int N_WRITES = REDUCE_N_WRITES>
 [[kernel]] void row_reduce_simple(
@@ -277,15 +279,15 @@ template <
   U totals[N_WRITES];
 
   // Move to the row
-  size_t out_idx = N_WRITES * (gid.y + gsize.y * size_t(gid.z));
+  IdxT out_idx = N_WRITES * (gid.y + gsize.y * IdxT(gid.z));
   if (out_idx + N_WRITES > out_size) {
     out_idx = out_size - N_WRITES;
   }
-  in += out_idx * reduction_size;
+  in += out_idx * IdxT(reduction_size);
   out += out_idx;
 
   // Each thread reduces across the row
-  int blocks = reduction_size / (lsize.x * N_READS);
+  int blocks = IdxT(reduction_size) / (lsize.x * N_READS);
   int extra = reduction_size - blocks * (lsize.x * N_READS);
   per_thread_row_reduce(
       totals, in, reduction_size, blocks, extra, lsize.x, lid.x);
@@ -306,6 +308,7 @@ template <
     typename T,
     typename U,
     typename Op,
+    typename IdxT,
     int NDIMS,
     int N_READS = REDUCE_N_READS>
 [[kernel]] void row_reduce_looped(
@@ -330,19 +333,20 @@ template <
   threadgroup U shared_vals[simd_size];
   U total = Op::init;
 
-  size_t out_idx = gid.y + gsize.y * size_t(gid.z);
+  IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
 
   // lid.x * N_READS breaks the per_thread_row_reduce interface a bit. Maybe it
   // needs a small refactor.
-  in += elem_to_loc(out_idx, shape, strides, ndim) + lid.x * N_READS;
+  in += elem_to_loc<IdxT>(out_idx, shape, strides, ndim) +
+      lid.x * N_READS;
 
-  looped_elem_to_loc<NDIMS> loop;
+  LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
   const device T* row;
-  int blocks = row_size / (lsize.x * N_READS);
+  int blocks = IdxT(row_size) / (lsize.x * N_READS);
   int extra = row_size - blocks * (lsize.x * N_READS);
 
-  for (size_t i = 0; i < non_row_reductions; i++) {
-    row = in + loop.location(i, reduce_shape, reduce_strides, reduce_ndim);
+  for (IdxT i = 0; i < non_row_reductions; i++) {
+    row = in + loop.location();
 
     // Each thread reduces across the row
     U row_total;
diff --git a/mlx/backend/metal/kernels/utils.h b/mlx/backend/metal/kernels/utils.h
index 1e1b91c0b..b894426d4 100644
--- a/mlx/backend/metal/kernels/utils.h
+++ b/mlx/backend/metal/kernels/utils.h
@@ -204,16 +204,21 @@ METAL_FUNC vec<stride_t, 3> elem_to_loc_3_nd(
 // Elem to loc in a loop utils
 ///////////////////////////////////////////////////////////////////////////////
 
-template <int dim, typename offset_t = size_t>
-struct looped_elem_to_loc {
-  looped_elem_to_loc<dim - 1, offset_t> inner_looper;
-  offset_t offset{0};
+template <int DIM, typename OffsetT = size_t, bool General = true>
+struct LoopedElemToLoc {
+  int dim;
+  LoopedElemToLoc<DIM - 1, OffsetT, General> inner_looper;
+  OffsetT offset{0};
   int index{0};
 
-  void next(const constant int* shape, const constant size_t* strides) {
-    index++;
-    offset += strides[dim - 1];
+  LoopedElemToLoc(int dim) : dim(dim), inner_looper(dim - 1) {}
 
+  void next(const constant int* shape, const constant size_t* strides) {
+    if (dim == 0) {
+      return;
+    }
+    index++;
+    offset += OffsetT(strides[dim - 1]);
     if (index >= shape[dim - 1]) {
       index = 0;
       inner_looper.next(shape, strides);
@@ -222,13 +227,21 @@ struct looped_elem_to_loc {
   }
 
   void next(int n, const constant int* shape, const constant size_t* strides) {
+    if (dim == 0) {
+      return;
+    }
     index += n;
-    offset += n * strides[dim - 1];
+    offset += n * OffsetT(strides[dim - 1]);
     if (index >= shape[dim - 1]) {
       int extra = index - shape[dim - 1];
+      if (extra >= shape[dim - 1]) {
+        inner_looper.next(1 + extra / shape[dim - 1], shape, strides);
+        extra = extra % shape[dim - 1];
+      } else {
+        inner_looper.next(shape, strides);
+      }
       index = 0;
-      inner_looper.next(shape, strides);
       offset = inner_looper.offset;
       if (extra > 0) {
        next(extra, shape, strides);
@@ -236,44 +249,61 @@ struct looped_elem_to_loc {
     }
   }
 
-  offset_t
-  location(offset_t, const constant int*, const constant size_t*, int) {
+  OffsetT location() {
     return offset;
   }
 };
 
-template <typename offset_t>
-struct looped_elem_to_loc<1, offset_t> {
-  offset_t offset{0};
+template <typename OffsetT>
+struct LoopedElemToLoc<1, OffsetT, true> {
+  int dim;
+  OffsetT offset{0};
+  uint index{0};
+
+  LoopedElemToLoc(int dim) : dim(dim) {}
+
+  void next(const constant int* shape, const constant size_t* strides) {
+    index++;
+    if (dim > 1) {
+      offset = elem_to_loc(index, shape, strides, dim);
+    } else {
+      offset += OffsetT(strides[0]);
+    }
+  }
+
+  void next(int n, const constant int* shape, const constant size_t* strides) {
+    index += n;
+    if (dim > 1) {
+      offset = elem_to_loc(index, shape, strides, dim);
+    } else {
+      offset = index * OffsetT(strides[0]);
+    }
+  }
+
+  OffsetT location() {
+    return offset;
+  }
+};
+
+template <typename OffsetT>
+struct LoopedElemToLoc<1, OffsetT, false> {
+  OffsetT offset{0};
+
+  LoopedElemToLoc(int) {}
 
   void next(const constant int*, const constant size_t* strides) {
-    offset += strides[0];
+    offset += OffsetT(strides[0]);
   }
 
   void next(int n, const constant int*, const constant size_t* strides) {
-    offset += n * strides[0];
+    offset += n * OffsetT(strides[0]);
  }
 
-  offset_t
-  location(offset_t, const constant int*, const constant size_t*, int) {
+  OffsetT location() {
     return offset;
   }
 };
 
-template <typename offset_t>
-struct looped_elem_to_loc<0, offset_t> {
-  void next(const constant int*, const constant size_t*) {}
-  void next(int, const constant int*, const constant size_t*) {}
-
-  offset_t location(
-      offset_t idx,
-      const constant int* shape,
-      const constant size_t* strides,
-      int ndim) {
-    return elem_to_loc(idx, shape, strides, ndim);
-  }
-};
-
 ///////////////////////////////////////////////////////////////////////////////
 // Calculation utils
 ///////////////////////////////////////////////////////////////////////////////
diff --git a/mlx/backend/metal/nojit_kernels.cpp b/mlx/backend/metal/nojit_kernels.cpp
index b5d9f5fa2..03b31197b 100644
--- a/mlx/backend/metal/nojit_kernels.cpp
+++ b/mlx/backend/metal/nojit_kernels.cpp
@@ -1,3 +1,4 @@
+// Copyright © 2024 Apple Inc.
 #include "mlx/backend/metal/kernels.h"
 #include "mlx/backend/metal/utils.h"
 
@@ -99,7 +100,7 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
     const std::string& kernel_name,
     const std::string&,
     const std::string&,
-    const array&) {
+    const Dtype&) {
   return d.get_kernel(kernel_name);
 }
 
@@ -108,8 +109,9 @@ MTL::ComputePipelineState* get_reduce_kernel(
     const std::string& kernel_name,
     const std::string&,
     const std::string&,
-    const array&,
-    const array&,
+    const Dtype&,
+    const Dtype&,
+    const std::string&,
     int,
     int,
     int) {
diff --git a/mlx/backend/metal/reduce.cpp b/mlx/backend/metal/reduce.cpp
index 960b1898e..15cbcc9af 100644
--- a/mlx/backend/metal/reduce.cpp
+++ b/mlx/backend/metal/reduce.cpp
@@ -2,7 +2,6 @@
 
 #include <algorithm>
 #include <cassert>
-#include <sstream>
 
 #include "mlx/backend/metal/copy.h"
 #include "mlx/backend/metal/device.h"
@@ -202,6 +201,16 @@ inline bool is_64b_dtype(Dtype dtype) {
   return dtype == int64 || dtype == uint64 || dtype == complex64;
 }
 
+inline int get_kernel_reduce_ndim(int reduce_ndim) {
+  if (reduce_ndim <= 1) {
+    return 1;
+  } else if (reduce_ndim == 2) {
+    return 2;
+  } else {
+    return 5;
+  }
+}
+
 inline int threadgroup_size_from_row_size(int row_size) {
   // 1 simdgroup per row smallish rows
   if (row_size <= 512) {
@@ -233,16 +242,51 @@ inline auto output_grid_for_col_reduce(
   return get_2d_grid_dims(out_shape, out_strides);
 }
 
+std::pair<Dtype, Dtype> remap_reduce_types(
+    const array& in,
+    const std::string& op_name) {
+  if (op_name == "sum" || op_name == "prod") {
+    if (issubdtype(in.dtype(), integer)) {
+      switch (in.dtype().size()) {
+        case 1:
+          return {int8, int32};
+        case 2:
+          return {int16, int32};
+        case 4:
+          return {int32, int32};
+        case 8:
+          return {int64, int64};
+      }
+    }
+    if (in.dtype() == bool_) {
+      return {int8, int32};
+    }
+    return {in.dtype(), in.dtype()};
+  } else if (op_name == "and" || op_name == "or") {
+    if (in.dtype().size() == 1) {
+      return {bool_, bool_};
+    } else if (in.dtype().size() == 2) {
+      return {int16, bool_};
+    } else if (in.dtype().size() == 4) {
+      return {int32, bool_};
+    } else {
+      return {int64, bool_};
+    }
+  }
+  return {in.dtype(), in.dtype()};
+}
+
 void init_reduce(
     array& out,
     const std::string& op_name,
     CommandEncoder& compute_encoder,
     metal::Device& d,
     const Stream& s) {
-  std::ostringstream kname;
+  auto [_, out_type] = remap_reduce_types(out, op_name);
   const std::string func_name = "init_reduce";
-  kname << func_name << "_" << op_name << type_to_name(out);
-  auto kernel = get_reduce_init_kernel(d, kname.str(), func_name, op_name, out);
+  std::string kname = func_name;
+  concatenate(kname, "_", op_name, type_to_name(out_type));
+  auto kernel = get_reduce_init_kernel(d, kname, func_name, op_name, out_type);
  size_t 
nthreads = out.size(); MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); @@ -263,10 +307,12 @@ void all_reduce_dispatch( metal::Device& d, const Stream& s) { // Set the kernel - std::ostringstream kname; + auto [in_type, out_type] = remap_reduce_types(in, op_name); const std::string func_name = "all_reduce"; - kname << func_name << "_" << op_name << type_to_name(in); - auto kernel = get_reduce_kernel(d, kname.str(), func_name, op_name, in, out); + std::string kname = func_name; + concatenate(kname, "_", op_name, type_to_name(in_type)); + auto kernel = get_reduce_kernel( + d, kname, func_name, op_name, in_type, out_type, "int64_t"); compute_encoder.set_compute_pipeline_state(kernel); size_t in_size = in.size(); @@ -300,7 +346,7 @@ void all_reduce_dispatch( } // Allocate an intermediate tensor to hold results if needed - array intermediate({n_rows}, out.dtype(), nullptr, {}); + array intermediate({n_rows}, out_type, nullptr, {}); intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes())); d.add_temporary(intermediate, s.index); @@ -318,10 +364,10 @@ void all_reduce_dispatch( compute_encoder.dispatch_threads(grid_dims, group_dims); // 2nd pass - std::ostringstream kname_2nd_pass; - kname_2nd_pass << func_name << "_" << op_name << type_to_name(intermediate); + std::string kname_2nd_pass = func_name; + concatenate(kname_2nd_pass, "_", op_name, type_to_name(intermediate)); auto kernel_2nd_pass = get_reduce_kernel( - d, kname_2nd_pass.str(), func_name, op_name, intermediate, out); + d, kname_2nd_pass, func_name, op_name, out_type, out_type, "int64_t"); compute_encoder.set_compute_pipeline_state(kernel_2nd_pass); size_t intermediate_size = n_rows; grid_dims = MTL::Size(threadgroup_2nd_pass, 1, 1); @@ -343,12 +389,30 @@ void row_reduce_small( metal::Device& d, const Stream& s) { // Set the kernel - std::ostringstream kname; - int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0; + int n = get_kernel_reduce_ndim(args.reduce_ndim); + auto [in_type, out_type] = remap_reduce_types(in, op_name); const std::string func_name = "row_reduce_small"; - kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in); - auto kernel = - get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n); + std::string kname = func_name; + bool large = in.size() > UINT32_MAX; + if (large) { + kname += "_large"; + } + concatenate( + kname, + "_", + std::to_string(n), + "_reduce_", + op_name, + type_to_name(in_type)); + auto kernel = get_reduce_kernel( + d, + kname, + func_name, + op_name, + in_type, + out_type, + large ? 
"size_t" : "uint", + n); compute_encoder.set_compute_pipeline_state(kernel); // Figure out the grid dims @@ -381,10 +445,13 @@ void row_reduce_simple( metal::Device& d, const Stream& s) { // Set the kernel - std::ostringstream kname; + auto [in_type, out_type] = remap_reduce_types(in, op_name); const std::string func_name = "row_reduce_simple"; - kname << func_name << "_" << op_name << type_to_name(in); - auto kernel = get_reduce_kernel(d, kname.str(), func_name, op_name, in, out); + std::string kname = func_name; + concatenate(kname, "_", op_name, type_to_name(in_type)); + + auto kernel = get_reduce_kernel( + d, kname, func_name, op_name, in_type, out_type, "size_t"); compute_encoder.set_compute_pipeline_state(kernel); // Figure out the grid dims @@ -417,13 +484,32 @@ void row_reduce_looped( CommandEncoder& compute_encoder, metal::Device& d, const Stream& s) { + auto [in_type, out_type] = remap_reduce_types(in, op_name); + // Set the kernel - std::ostringstream kname; - int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0; + int n = get_kernel_reduce_ndim(args.reduce_ndim); const std::string func_name = "row_reduce_looped"; - kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in); - auto kernel = - get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n); + std::string kname = func_name; + bool large = in.size() > UINT32_MAX; + if (large) { + kname += "_large"; + } + concatenate( + kname, + "_", + std::to_string(n), + "_reduce_", + op_name, + type_to_name(in_type)); + auto kernel = get_reduce_kernel( + d, + kname, + func_name, + op_name, + in_type, + out_type, + large ? "size_t" : "uint", + n); compute_encoder.set_compute_pipeline_state(kernel); // Figure out the grid @@ -475,6 +561,8 @@ void strided_reduce_small( CommandEncoder& compute_encoder, metal::Device& d, const Stream& s) { + auto [in_type, out_type] = remap_reduce_types(in, op_name); + // Figure out the grid dims MTL::Size grid_dims, group_dims; @@ -483,12 +571,29 @@ void strided_reduce_small( args.reduce_strides.push_back(args.reduction_stride); args.reduce_ndim++; - int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0; - std::ostringstream kname; + int n = get_kernel_reduce_ndim(args.reduce_ndim); const std::string func_name = "col_reduce_small"; - kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in); - auto kernel = - get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n); + std::string kname = func_name; + bool large = in.size() > UINT32_MAX; + if (large) { + kname += "_large"; + } + concatenate( + kname, + "_", + std::to_string(n), + "_reduce_", + op_name, + type_to_name(in_type)); + auto kernel = get_reduce_kernel( + d, + kname, + func_name, + op_name, + in_type, + out_type, + large ? 
"size_t" : "uint", + n); compute_encoder.set_compute_pipeline_state(kernel); const int n_reads = 4; @@ -522,6 +627,7 @@ void strided_reduce_longcolumn( CommandEncoder& compute_encoder, metal::Device& d, const Stream& s) { + auto [in_type, out_type] = remap_reduce_types(in, op_name); size_t total_reduction_size = args.reduction_size * args.non_col_reductions; size_t outer_blocks = 32; if (total_reduction_size >= 32768) { @@ -534,7 +640,7 @@ void strided_reduce_longcolumn( intermediate_shape.push_back(outer_blocks); intermediate_shape.insert( intermediate_shape.end(), out.shape().begin(), out.shape().end()); - array intermediate(std::move(intermediate_shape), out.dtype(), nullptr, {}); + array intermediate(std::move(intermediate_shape), out_type, nullptr, {}); intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes())); d.add_temporary(intermediate, s.index); @@ -556,12 +662,29 @@ void strided_reduce_longcolumn( MTL::Size group_dims(threadgroup_x, threadgroup_y, 1); // Set the kernel - int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0; - std::ostringstream kname; - const std::string func_name = "col_reduce_longcolumn"; - kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in); - auto kernel = - get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n); + int n = get_kernel_reduce_ndim(args.reduce_ndim); + std::string func_name = "col_reduce_longcolumn"; + std::string kname = func_name; + bool large = in.size() > UINT32_MAX; + if (large) { + kname += "_large"; + } + concatenate( + kname, + "_", + std::to_string(n), + "_reduce_", + op_name, + type_to_name(in_type)); + auto kernel = get_reduce_kernel( + d, + kname, + func_name, + op_name, + in_type, + out_type, + large ? "size_t" : "uint", + n); compute_encoder.set_compute_pipeline_state(kernel); // Launch @@ -581,15 +704,21 @@ void strided_reduce_longcolumn( group_dims = MTL::Size(256, 1, 1); // Set the 2nd kernel - const std::string second_kernel = "col_reduce_looped_1_32_32_reduce_" + - op_name + type_to_name(intermediate); + func_name = "col_reduce_looped"; + kname = func_name; + large = intermediate.size() > UINT32_MAX; + if (large) { + kname += "_large"; + } + concatenate(kname, "_1_32_32_reduce_", op_name, type_to_name(intermediate)); kernel = get_reduce_kernel( d, - second_kernel, - "col_reduce_looped", + kname, + func_name, op_name, - intermediate, - out, + intermediate.dtype(), + out_type, + large ? "size_t" : "uint", 1, 32, 32); @@ -609,6 +738,8 @@ void strided_reduce_looped( CommandEncoder& compute_encoder, metal::Device& d, const Stream& s) { + auto [in_type, out_type] = remap_reduce_types(in, op_name); + // Prepare the arguments for the kernel args.reduce_shape.push_back(args.reduction_size); args.reduce_strides.push_back(args.reduction_stride); @@ -626,13 +757,35 @@ void strided_reduce_looped( MTL::Size group_dims(threadgroup_size, 1, 1); // Set the kernel - int n = (args.reduce_ndim < 5) ? 
std::max(1, args.reduce_ndim) : 0;
-  std::ostringstream kname;
-  const std::string func_name = "col_reduce_looped";
-  kname << func_name << "_" << n << "_" << BM << "_" << BN << "_reduce_"
-        << op_name << type_to_name(in);
-  auto kernel =
-      get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n, BM, BN);
+  int n = get_kernel_reduce_ndim(args.reduce_ndim);
+  std::string func_name = "col_reduce_looped";
+  std::string kname = func_name;
+  bool large = in.size() > UINT32_MAX;
+  if (large) {
+    kname += "_large";
+  }
+  concatenate(
+      kname,
+      "_",
+      std::to_string(n),
+      "_",
+      std::to_string(BM),
+      "_",
+      std::to_string(BN),
+      "_reduce_",
+      op_name,
+      type_to_name(in_type));
+  auto kernel = get_reduce_kernel(
+      d,
+      kname,
+      func_name,
+      op_name,
+      in_type,
+      out_type,
+      large ? "size_t" : "uint",
+      n,
+      BM,
+      BN);
   compute_encoder.set_compute_pipeline_state(kernel);
 
   // Launch
@@ -650,13 +803,15 @@ void strided_reduce_2pass(
     CommandEncoder& compute_encoder,
     metal::Device& d,
     const Stream& s) {
+  auto [in_type, out_type] = remap_reduce_types(in, op_name);
+
   // Prepare the temporary accumulator
   std::vector<int> intermediate_shape;
   intermediate_shape.reserve(out.ndim() + 1);
   intermediate_shape.push_back(32);
   intermediate_shape.insert(
       intermediate_shape.end(), out.shape().begin(), out.shape().end());
-  array intermediate(std::move(intermediate_shape), out.dtype(), nullptr, {});
+  array intermediate(std::move(intermediate_shape), out_type, nullptr, {});
   intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
   d.add_temporary(intermediate, s.index);
 
@@ -679,13 +834,35 @@ void strided_reduce_2pass(
   MTL::Size group_dims(threadgroup_size, 1, 1);
 
   // Set the kernel
-  int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
-  std::ostringstream kname;
-  const std::string func_name = "col_reduce_2pass";
-  kname << func_name << "_" << n << "_" << BM << "_" << BN << "_reduce_"
-        << op_name << type_to_name(in);
-  auto kernel =
-      get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n, BM, BN);
+  int n = get_kernel_reduce_ndim(args.reduce_ndim);
+  std::string func_name = "col_reduce_2pass";
+  std::string kname = func_name;
+  bool large = in.size() > UINT32_MAX;
+  if (large) {
+    kname += "_large";
+  }
+  concatenate(
+      kname,
+      "_",
+      std::to_string(n),
+      "_",
+      std::to_string(BM),
+      "_",
+      std::to_string(BN),
+      "_reduce_",
+      op_name,
+      type_to_name(in_type));
+  auto kernel = get_reduce_kernel(
+      d,
+      kname,
+      func_name,
+      op_name,
+      in_type,
+      out_type,
+      large ? "size_t" : "uint",
+      n,
+      BM,
+      BN);
   compute_encoder.set_compute_pipeline_state(kernel);
 
   // Launch
@@ -703,15 +880,21 @@ void strided_reduce_2pass(
   grid_dims = MTL::Size(threadgroup_size * ((out.size() + BN - 1) / BN), 1, 1);
 
   // Set the 2nd kernel
-  const std::string second_kernel = "col_reduce_looped_1_32_32_reduce_" +
-      op_name + type_to_name(intermediate);
+  func_name = "col_reduce_looped";
+  kname = func_name;
+  large = intermediate.size() > UINT32_MAX;
+  if (large) {
+    kname += "_large";
+  }
+  concatenate(kname, "_1_32_32_reduce_", op_name, type_to_name(intermediate));
   kernel = get_reduce_kernel(
       d,
-      second_kernel,
-      "col_reduce_looped",
+      kname,
+      func_name,
       op_name,
-      intermediate,
-      out,
+      intermediate.dtype(),
+      out_type,
+      large ? "size_t" : "uint",
       1,
       32,
       32);
@@ -780,7 +963,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
       op_name = "sum";
       break;
     case Reduce::Prod:
-      op_name = out.dtype() == bool_ ? "and" : "prod";
"and" : "prod"; + op_name = "prod"; break; case Reduce::Min: op_name = out.dtype() == bool_ ? "and" : "min"; diff --git a/mlx/backend/metal/utils.cpp b/mlx/backend/metal/utils.cpp index deff629eb..22beb5d43 100644 --- a/mlx/backend/metal/utils.cpp +++ b/mlx/backend/metal/utils.cpp @@ -6,9 +6,9 @@ using namespace mlx; namespace mlx::core { -std::string type_to_name(const array& a) { +std::string type_to_name(const Dtype& t) { std::string tname; - switch (a.dtype()) { + switch (t) { case bool_: tname = "bool_"; break; @@ -52,6 +52,10 @@ std::string type_to_name(const array& a) { return tname; } +std::string type_to_name(const array& a) { + return type_to_name(a.dtype()); +} + MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2 /* = 10 */) { int pows[3] = {0, 0, 0}; int sum = 0; diff --git a/mlx/backend/metal/utils.h b/mlx/backend/metal/utils.h index 366da6287..082e4d116 100644 --- a/mlx/backend/metal/utils.h +++ b/mlx/backend/metal/utils.h @@ -8,6 +8,7 @@ namespace mlx::core { +std::string type_to_name(const Dtype& t); std::string type_to_name(const array& a); // Compute the thread block dimensions which fit the given diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 4193b08e0..d02ff4b48 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -1615,7 +1615,14 @@ array sum( } auto [out_shape, sorted_axes, squeezed_shape, is_noop] = compute_reduce_shape(axes, a.shape()); - auto out_type = a.dtype() == bool_ ? int32 : a.dtype(); + Dtype out_type = a.dtype(); + if (issubdtype(a.dtype(), signedinteger)) { + out_type = a.dtype().size() <= 4 ? int32 : int64; + } else if (issubdtype(a.dtype(), unsignedinteger)) { + out_type = a.dtype().size() <= 4 ? uint32 : uint64; + } else if (a.dtype() == bool_) { + out_type = int32; + } auto out = (is_noop) ? astype(a, out_type, s) : array( @@ -1760,11 +1767,19 @@ array prod( } auto [out_shape, sorted_axes, squeezed_shape, is_noop] = compute_reduce_shape(axes, a.shape()); + Dtype out_type = a.dtype(); + if (issubdtype(a.dtype(), signedinteger)) { + out_type = a.dtype().size() <= 4 ? int32 : int64; + } else if (issubdtype(a.dtype(), unsignedinteger)) { + out_type = a.dtype().size() <= 4 ? uint32 : uint64; + } else if (a.dtype() == bool_) { + out_type = int32; + } auto out = (is_noop) ? a : array( std::move(out_shape), - a.dtype(), + out_type, std::make_shared(to_stream(s), Reduce::Prod, sorted_axes), {a}); if (!keepdims) { diff --git a/python/tests/test_reduce.py b/python/tests/test_reduce.py index 368b2128a..9012216ba 100644 --- a/python/tests/test_reduce.py +++ b/python/tests/test_reduce.py @@ -131,6 +131,28 @@ class TestReduce(mlx_tests.MLXTestCase): mxsum = y.sum().item() self.assertEqual(npsum, mxsum) + def test_many_reduction_axes(self): + + def check(x, axes): + expected = x + for ax in axes: + expected = mx.sum(expected, axis=ax, keepdims=True) + out = mx.sum(x, axis=axes, keepdims=True) + self.assertTrue(mx.array_equal(out, expected)) + + x = mx.random.randint(0, 10, shape=(4, 4, 4, 4, 4)) + check(x, (0, 2, 4)) + + x = mx.random.randint(0, 10, shape=(4, 4, 4, 4, 4, 4, 4)) + check(x, (0, 2, 4, 6)) + + x = mx.random.randint(0, 10, shape=(4, 4, 4, 4, 4, 4, 4, 4, 4)) + check(x, (0, 2, 4, 6, 8)) + + x = mx.random.randint(0, 10, shape=(4, 4, 4, 4, 4, 4, 4, 4, 4, 128)) + x = x.transpose(1, 0, 2, 3, 4, 5, 6, 7, 8, 9) + check(x, (1, 3, 5, 7, 9)) + if __name__ == "__main__": unittest.main(failfast=True)