Reduce specializations (#1607)

* start of reduce specializations

* fix all reduce

* fix many dims

* fix

* non-jit tests clear

* cleanup instantiations

* cpu merges

* change dim specializations

* optimize

* fix jit

* fix jit

* use higher precision for integer sum+prod

* fixes
Awni Hannun 2024-11-21 19:53:00 -08:00 committed by GitHub
parent dcca0d7477
commit 0c5eea226b
14 changed files with 733 additions and 406 deletions
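The "use higher precision for integer sum+prod" bullet is the easiest behavior change to see in isolation. Below is a minimal standalone C++ sketch (not MLX code) of why the new dispatch accumulates 1-, 2-, and 4-byte integer inputs in `int32_t`:

```cpp
// Standalone illustration: why integer sum/prod accumulate in a wider type.
// Accumulating int8 values in int8 wraps around, while an int32_t
// accumulator returns the expected result.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int8_t> data(1000, 1); // the sum should be 1000

  int8_t narrow = 0;
  int32_t wide = 0;
  for (auto x : data) {
    narrow = narrow + x; // wraps modulo 256
    wide = wide + x;
  }
  std::printf("int8 accumulator:  %d\n", narrow); // 1000 mod 256 -> -24
  std::printf("int32 accumulator: %d\n", wide);   // 1000
  return 0;
}
```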

View File

@@ -120,48 +120,56 @@ struct MinReduce {
 };

 template <typename InT>
-void reduce_dispatch_out(
+void reduce_dispatch_and_or(
     const array& in,
     array& out,
     Reduce::ReduceType rtype,
     const std::vector<int>& axes) {
-  switch (rtype) {
-    case Reduce::And: {
-      reduction_op<InT, bool>(in, out, axes, true, AndReduce());
-      break;
-    }
-    case Reduce::Or: {
-      reduction_op<InT, bool>(in, out, axes, false, OrReduce());
-      break;
-    }
-    case Reduce::Sum: {
-      auto op = [](auto y, auto x) { (*y) = (*y) + x; };
-      if (out.dtype() == int32) {
-        // special case since the input type can be bool
-        reduction_op<InT, int32_t>(in, out, axes, 0, op);
-      } else {
-        reduction_op<InT, InT>(in, out, axes, 0, op);
-      }
-      break;
-    }
-    case Reduce::Prod: {
-      auto op = [](auto y, auto x) { (*y) *= x; };
-      reduction_op<InT, InT>(in, out, axes, 1, op);
-      break;
-    }
-    case Reduce::Max: {
-      auto init = Limits<InT>::min;
-      reduction_op<InT, InT>(in, out, axes, init, MaxReduce());
-      break;
-    }
-    case Reduce::Min: {
-      auto init = Limits<InT>::max;
-      reduction_op<InT, InT>(in, out, axes, init, MinReduce());
-      break;
-    }
+  if (rtype == Reduce::And) {
+    reduction_op<InT, bool>(in, out, axes, true, AndReduce());
+  } else {
+    reduction_op<InT, bool>(in, out, axes, false, OrReduce());
+  }
+}
+
+template <typename InT>
+void reduce_dispatch_sum_prod(
+    const array& in,
+    array& out,
+    Reduce::ReduceType rtype,
+    const std::vector<int>& axes) {
+  if (rtype == Reduce::Sum) {
+    auto op = [](auto y, auto x) { (*y) = (*y) + x; };
+    if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
+      reduction_op<InT, int32_t>(in, out, axes, 0, op);
+    } else {
+      reduction_op<InT, InT>(in, out, axes, 0, op);
+    }
+  } else {
+    auto op = [](auto y, auto x) { (*y) *= x; };
+    if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
+      reduction_op<InT, int32_t>(in, out, axes, 1, op);
+    } else {
+      reduction_op<InT, InT>(in, out, axes, 1, op);
+    }
+  }
+}
+
+template <typename InT>
+void reduce_dispatch_min_max(
+    const array& in,
+    array& out,
+    Reduce::ReduceType rtype,
+    const std::vector<int>& axes) {
+  if (rtype == Reduce::Max) {
+    auto init = Limits<InT>::min;
+    reduction_op<InT, InT>(in, out, axes, init, MaxReduce());
+  } else {
+    auto init = Limits<InT>::max;
+    reduction_op<InT, InT>(in, out, axes, init, MinReduce());
   }
 }

 } // namespace

 void nd_loop(
@@ -190,46 +198,114 @@ void nd_loop(
 void Reduce::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
-  switch (in.dtype()) {
-    case bool_:
-      reduce_dispatch_out<bool>(in, out, reduce_type_, axes_);
-      break;
-    case uint8:
-      reduce_dispatch_out<uint8_t>(in, out, reduce_type_, axes_);
-      break;
-    case uint16:
-      reduce_dispatch_out<uint16_t>(in, out, reduce_type_, axes_);
-      break;
-    case uint32:
-      reduce_dispatch_out<uint32_t>(in, out, reduce_type_, axes_);
-      break;
-    case uint64:
-      reduce_dispatch_out<uint64_t>(in, out, reduce_type_, axes_);
-      break;
-    case int8:
-      reduce_dispatch_out<uint8_t>(in, out, reduce_type_, axes_);
-      break;
-    case int16:
-      reduce_dispatch_out<uint16_t>(in, out, reduce_type_, axes_);
-      break;
-    case int32:
-      reduce_dispatch_out<int32_t>(in, out, reduce_type_, axes_);
-      break;
-    case int64:
-      reduce_dispatch_out<int64_t>(in, out, reduce_type_, axes_);
-      break;
-    case float16:
-      reduce_dispatch_out<float16_t>(in, out, reduce_type_, axes_);
-      break;
-    case float32:
-      reduce_dispatch_out<float>(in, out, reduce_type_, axes_);
-      break;
-    case bfloat16:
-      reduce_dispatch_out<bfloat16_t>(in, out, reduce_type_, axes_);
-      break;
-    case complex64:
-      reduce_dispatch_out<complex64_t>(in, out, reduce_type_, axes_);
-      break;
+  switch (reduce_type_) {
+    case Reduce::And:
+    case Reduce::Or: {
+      switch (in.dtype()) {
+        case bool_:
+        case uint8:
+        case int8:
+          reduce_dispatch_and_or<int8_t>(in, out, reduce_type_, axes_);
+          break;
+        case int16:
+        case uint16:
+        case float16:
+        case bfloat16:
+          reduce_dispatch_and_or<int16_t>(in, out, reduce_type_, axes_);
+          break;
+        case uint32:
+        case int32:
+        case float32:
+          reduce_dispatch_and_or<int32_t>(in, out, reduce_type_, axes_);
+          break;
+        case uint64:
+        case int64:
+        case complex64:
+          reduce_dispatch_and_or<int64_t>(in, out, reduce_type_, axes_);
+          break;
+      }
+      break;
+    }
+    case Reduce::Sum:
+    case Reduce::Prod: {
+      switch (in.dtype()) {
+        case bool_:
+        case uint8:
+        case int8:
+          reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
+          break;
+        case int16:
+        case uint16:
+          reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
+          break;
+        case int32:
+        case uint32:
+          reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
+          break;
+        case int64:
+        case uint64:
+          reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
+          break;
+        case float16:
+          reduce_dispatch_sum_prod<float16_t>(in, out, reduce_type_, axes_);
+          break;
+        case bfloat16:
+          reduce_dispatch_sum_prod<bfloat16_t>(in, out, reduce_type_, axes_);
+          break;
+        case float32:
+          reduce_dispatch_sum_prod<float>(in, out, reduce_type_, axes_);
+          break;
+        case complex64:
+          reduce_dispatch_sum_prod<complex64_t>(in, out, reduce_type_, axes_);
+          break;
+      }
+      break;
+    }
+    case Reduce::Max:
+    case Reduce::Min: {
+      switch (in.dtype()) {
+        case bool_:
+          reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_);
+          break;
+        case uint8:
+          reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
+          break;
+        case uint16:
+          reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
+          break;
+        case uint32:
+          reduce_dispatch_min_max<uint32_t>(in, out, reduce_type_, axes_);
+          break;
+        case uint64:
+          reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);
+          break;
+        case int8:
+          reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
+          break;
+        case int16:
+          reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
+          break;
+        case int32:
+          reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);
+          break;
+        case int64:
+          reduce_dispatch_min_max<int64_t>(in, out, reduce_type_, axes_);
+          break;
+        case float16:
+          reduce_dispatch_min_max<float16_t>(in, out, reduce_type_, axes_);
+          break;
+        case float32:
+          reduce_dispatch_min_max<float>(in, out, reduce_type_, axes_);
+          break;
+        case bfloat16:
+          reduce_dispatch_min_max<bfloat16_t>(in, out, reduce_type_, axes_);
+          break;
+        case complex64:
+          reduce_dispatch_min_max<complex64_t>(in, out, reduce_type_, axes_);
+          break;
+      }
+      break;
+    }
   }
 }
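The CPU dispatch above folds many input dtypes onto a few template instantiations: for and/or only the byte width of the input matters. A standalone sketch of that remapping (the `Dtype` enum and function name here are hypothetical stand-ins, not the MLX API):

```cpp
// Sketch: collapsing many dtypes onto a few kernel instantiations, mirroring
// the switch above. The returned string names the C++ type the kernel is
// compiled for.
#include <cstdio>

enum class Dtype { bool_, u8, i8, i16, u16, f16, bf16, u32, i32, f32, u64, i64, c64 };

// For and/or only the byte width matters, so 13 dtypes map to 4 kernels.
const char* and_or_kernel(Dtype t) {
  switch (t) {
    case Dtype::bool_: case Dtype::u8: case Dtype::i8: return "int8_t";
    case Dtype::i16: case Dtype::u16:
    case Dtype::f16: case Dtype::bf16: return "int16_t";
    case Dtype::u32: case Dtype::i32: case Dtype::f32: return "int32_t";
    default: return "int64_t";
  }
}

int main() {
  std::printf("float16 and/or runs the %s kernel\n", and_or_kernel(Dtype::f16));
}
```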

View File

@@ -1,5 +1,4 @@
 // Copyright © 2024 Apple Inc.
 #include "mlx/backend/common/compiled.h"
 #include "mlx/backend/metal/jit/arange.h"
 #include "mlx/backend/metal/jit/gemv_masked.h"
@@ -338,17 +337,17 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
     const std::string& kernel_name,
     const std::string& func_name,
     const std::string& op_name,
-    const array& out) {
+    const Dtype& out_type) {
   auto lib = d.get_library(kernel_name, [&]() {
-    std::ostringstream kernel_source;
     std::string op_type = op_name;
     op_type[0] = std::toupper(op_name[0]);
-    auto out_type = get_type_string(out.dtype());
-    std::string op = op_type + "<" + out_type + ">";
-    kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce();
-    kernel_source << get_template_definition(
-        kernel_name, func_name, out_type, op);
-    return kernel_source.str();
+    auto out_t = get_type_string(out_type);
+    std::string op = op_type + "<" + out_t + ">";
+    std::string kernel_source = metal::utils();
+    kernel_source += metal::reduce_utils();
+    kernel_source += metal::reduce();
+    kernel_source += get_template_definition(kernel_name, func_name, out_t, op);
+    return kernel_source;
   });
   return d.get_kernel(kernel_name, lib);
 }
@@ -358,30 +357,31 @@ MTL::ComputePipelineState* get_reduce_kernel(
     const std::string& kernel_name,
     const std::string& func_name,
     const std::string& op_name,
-    const array& in,
-    const array& out,
+    const Dtype& in_type,
+    const Dtype& out_type,
+    const std::string& idx_t,
     int ndim /* = -1 */,
     int bm /* = -1 */,
     int bn /* = -1 */) {
   auto lib = d.get_library(kernel_name, [&]() {
     std::string op_type = op_name;
     op_type[0] = std::toupper(op_name[0]);
-    std::ostringstream kernel_source;
-    auto in_type = get_type_string(in.dtype());
-    auto out_type = get_type_string(out.dtype());
-    std::string op = op_type + "<" + out_type + ">";
-    kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce();
+    auto in_t = get_type_string(in_type);
+    auto out_t = get_type_string(out_type);
+    std::string op = op_type + "<" + out_t + ">";
+    std::string kernel_source = metal::utils();
+    concatenate(kernel_source, metal::reduce_utils(), metal::reduce());
     if (bm >= 0) {
-      kernel_source << get_template_definition(
-          kernel_name, func_name, in_type, out_type, op, ndim, bm, bn);
+      kernel_source += get_template_definition(
+          kernel_name, func_name, in_t, out_t, op, idx_t, ndim, bm, bn);
     } else if (ndim >= 0) {
-      kernel_source << get_template_definition(
-          kernel_name, func_name, in_type, out_type, op, ndim);
+      kernel_source += get_template_definition(
+          kernel_name, func_name, in_t, out_t, op, idx_t, ndim);
     } else {
-      kernel_source << get_template_definition(
-          kernel_name, func_name, in_type, out_type, op);
+      kernel_source += get_template_definition(
+          kernel_name, func_name, in_t, out_t, op, idx_t);
     }
-    return kernel_source.str();
+    return kernel_source;
   });
   auto st = d.get_kernel(kernel_name, lib);
   return st;
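`get_reduce_kernel` now threads an `idx_t` string into `get_template_definition`. As a rough sketch of what such a helper has to produce (the generated source shape is an assumption; the real MLX helper may differ):

```cpp
// Minimal sketch (assumed output shape): join every template argument,
// including the new index-type string, into a host_name'd Metal template
// instantiation.
#include <iostream>
#include <sstream>
#include <string>

template <typename... Args>
std::string template_definition(
    const std::string& name, const std::string& func, const Args&... args) {
  std::ostringstream s;
  ((s << args << ", "), ...); // requires at least one argument
  std::string a = s.str();
  a.resize(a.size() - 2); // drop the trailing ", "
  return "template [[host_name(\"" + name + "\")]] [[kernel]] decltype(" +
      func + "<" + a + ">) " + func + "<" + a + ">;\n";
}

int main() {
  // Illustrative arguments matching the new calling convention.
  std::cout << template_definition(
      "all_reduce_sumint8", "all_reduce", "int8_t", "int32_t",
      "Sum<int32_t>", "int64_t");
}
```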

View File

@@ -81,15 +81,16 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
     const std::string& kernel_name,
     const std::string& func_name,
     const std::string& op_name,
-    const array& out);
+    const Dtype& out_type);

 MTL::ComputePipelineState* get_reduce_kernel(
     metal::Device& d,
     const std::string& kernel_name,
     const std::string& func_name,
     const std::string& op_name,
-    const array& in,
-    const array& out,
+    const Dtype& in_type,
+    const Dtype& out_type,
+    const std::string& idx_t,
     int ndim = -1,
     int bm = -1,
     int bn = -1);

View File

@@ -10,186 +10,156 @@
 #include "mlx/backend/metal/kernels/reduction/ops.h"
 #include "mlx/backend/metal/kernels/reduce.h"

-#define instantiate_reduce_helper_floats(inst_f, name, op) \
-  inst_f(name, float16, half, op) \
-  inst_f(name, float32, float, op) \
-  inst_f(name, bfloat16, bfloat16_t, op)
-
-#define instantiate_reduce_helper_uints(inst_f, name, op) \
-  inst_f(name, uint8, uint8_t, op) \
-  inst_f(name, uint16, uint16_t, op) \
-  inst_f(name, uint32, uint32_t, op)
-
-#define instantiate_reduce_helper_ints(inst_f, name, op) \
-  inst_f(name, int8, int8_t, op) \
-  inst_f(name, int16, int16_t, op) \
-  inst_f(name, int32, int32_t, op)
-
-#define instantiate_reduce_helper_64b(inst_f, name, op) \
-  inst_f(name, int64, int64_t, op) \
-  inst_f(name, uint64, uint64_t, op) \
-  inst_f(name, complex64, complex64_t, op)
-
-#define instantiate_reduce_helper_types(inst_f, name, op) \
-  instantiate_reduce_helper_floats(inst_f, name, op) \
-  instantiate_reduce_helper_uints(inst_f, name, op) \
-  instantiate_reduce_helper_ints(inst_f, name, op)
-
-#define instantiate_reduce_ops(inst_f, type_f) \
-  type_f(inst_f, sum, Sum) \
-  type_f(inst_f, prod, Prod) \
-  type_f(inst_f, min, Min) \
-  type_f(inst_f, max, Max)
-
-// Special case for bool reductions
-#define instantiate_reduce_from_types_helper( \
-    inst_f, name, tname, itype, otype, op) \
-  inst_f(name##tname, itype, otype, op)
-
-#define instantiate_reduce_from_types(inst_f, name, otype, op) \
-  instantiate_reduce_from_types_helper( \
-      inst_f, name, bool_, bool, otype, op) \
-  instantiate_reduce_from_types_helper( \
-      inst_f, name, uint8, uint8_t, otype, op) \
-  instantiate_reduce_from_types_helper( \
-      inst_f, name, uint16, uint16_t, otype, op) \
-  instantiate_reduce_from_types_helper( \
-      inst_f, name, uint32, uint32_t, otype, op) \
-  instantiate_reduce_from_types_helper( \
-      inst_f, name, uint64, uint64_t, otype, op) \
-  instantiate_reduce_from_types_helper( \
-      inst_f, name, int8, int8_t, otype, op) \
-  instantiate_reduce_from_types_helper( \
-      inst_f, name, int16, int16_t, otype, op) \
-  instantiate_reduce_from_types_helper( \
-      inst_f, name, int32, int32_t, otype, op) \
-  instantiate_reduce_from_types_helper( \
-      inst_f, name, int64, int64_t, otype, op) \
-  instantiate_reduce_from_types_helper( \
-      inst_f, name, float16, half, otype, op) \
-  instantiate_reduce_from_types_helper( \
-      inst_f, \
-      name, \
-      float32, \
-      float, \
-      otype, \
-      op) \
-  instantiate_reduce_from_types_helper( \
-      inst_f, \
-      name, \
-      bfloat16, \
-      bfloat16_t, \
-      otype, \
-      op)
-
-#define instantiate_init_reduce(name, otype, op) \
-  instantiate_kernel("init_reduce_" #name, \
-                     init_reduce, \
-                     otype, op)
-
-#define instantiate_init_reduce_helper(name, tname, type, op) \
-  instantiate_init_reduce(name##tname, type, op<type>)
-
-instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_types)
-instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_64b)
-
-instantiate_init_reduce(andbool_, bool, And<bool>)
-instantiate_init_reduce(orbool_, bool, Or<bool>)
+#define instantiate_init_reduce(name, tname, type, op) \
+  instantiate_kernel("init_reduce_" #name #tname, init_reduce, type, op<type>)
+
+instantiate_init_reduce(and, bool_, bool, And)
+instantiate_init_reduce(or, bool_, bool, Or)
+
+#define instantiate_init_sum_prod(name, op) \
+  instantiate_init_reduce(name, int32, int32_t, op) \
+  instantiate_init_reduce(name, int64, int64_t, op) \
+  instantiate_init_reduce(name, float16, float16_t, op) \
+  instantiate_init_reduce(name, bfloat16, bfloat16_t, op) \
+  instantiate_init_reduce(name, float32, float, op) \
+  instantiate_init_reduce(name, complex64, complex64_t, op)
+
+instantiate_init_sum_prod(sum, Sum)
+instantiate_init_sum_prod(prod, Prod)
+
+#define instantiate_init_min_max(name, op) \
+  instantiate_init_reduce(name, bool_, bool, op) \
+  instantiate_init_reduce(name, int8, int8_t, op) \
+  instantiate_init_reduce(name, int16, int16_t, op) \
+  instantiate_init_reduce(name, int32, int32_t, op) \
+  instantiate_init_reduce(name, int64, int64_t, op) \
+  instantiate_init_reduce(name, uint8, uint8_t, op) \
+  instantiate_init_reduce(name, uint16, uint16_t, op) \
+  instantiate_init_reduce(name, uint32, uint32_t, op) \
+  instantiate_init_reduce(name, uint64, uint64_t, op) \
+  instantiate_init_reduce(name, float16, float16_t, op) \
+  instantiate_init_reduce(name, bfloat16, bfloat16_t, op) \
+  instantiate_init_reduce(name, float32, float, op) \
+  instantiate_init_reduce(name, complex64, complex64_t, op)
+
+instantiate_init_min_max(min, Min)
+instantiate_init_min_max(max, Max)

 #define instantiate_all_reduce(name, itype, otype, op) \
   instantiate_kernel("all_reduce_" #name, \
                      all_reduce, \
                      itype, otype, op)

-#define instantiate_same_all_reduce_helper(name, tname, type, op) \
-  instantiate_all_reduce(name##tname, type, type, op<type>)
-
-instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_types)
-instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_64b)
-
-instantiate_reduce_from_types(instantiate_all_reduce, and, bool, And<bool>)
-instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or<bool>)
-
-// special case bool with larger output type
-instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)
-
-#define instantiate_col_reduce_small(name, itype, otype, op, dim) \
-  instantiate_kernel("col_reduce_small_" #dim "_reduce_" #name, \
-                     col_reduce_small, \
-                     itype, otype, op, dim) \
-  instantiate_kernel("col_reduce_longcolumn_" #dim "_reduce_" #name, \
-                     col_reduce_longcolumn, \
-                     itype, otype, op, dim)
-
-#define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \
-  instantiate_kernel("col_reduce_looped_" #dim "_" #bm "_" #bn "_reduce_" #name, \
-                     col_reduce_looped, \
-                     itype, otype, op, dim, bm, bn)
-
-#define instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, bm, bn) \
-  instantiate_kernel("col_reduce_2pass_" #dim "_" #bm "_" #bn "_reduce_" #name, \
-                     col_reduce_2pass, \
-                     itype, otype, op, dim, bm, bn)
+#define instantiate_col_reduce_small(name, itype, otype, op, dim) \
+  instantiate_kernel("col_reduce_small_" #dim "_reduce_" #name, \
+                     col_reduce_small, \
+                     itype, otype, op, uint, dim) \
+  instantiate_kernel("col_reduce_longcolumn_" #dim "_reduce_" #name, \
+                     col_reduce_longcolumn, \
+                     itype, otype, op, uint, dim) \
+  instantiate_kernel("col_reduce_small_large_" #dim "_reduce_" #name, \
+                     col_reduce_small, \
+                     itype, otype, op, size_t, dim) \
+  instantiate_kernel("col_reduce_longcolumn_large_" #dim "_reduce_" #name, \
+                     col_reduce_longcolumn, \
+                     itype, otype, op, size_t, dim)
+
+#define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \
+  instantiate_kernel("col_reduce_looped_" #dim "_" #bm "_" #bn "_reduce_" #name, \
+                     col_reduce_looped, \
+                     itype, otype, op, uint, dim, bm, bn) \
+  instantiate_kernel("col_reduce_looped_large_" #dim "_" #bm "_" #bn "_reduce_" #name, \
+                     col_reduce_looped, \
+                     itype, otype, op, size_t, dim, bm, bn)
+
+#define instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, bm, bn) \
+  instantiate_kernel("col_reduce_2pass_" #dim "_" #bm "_" #bn "_reduce_" #name, \
+                     col_reduce_2pass, \
+                     itype, otype, op, uint, dim, bm, bn) \
+  instantiate_kernel("col_reduce_2pass_large_" #dim "_" #bm "_" #bn "_reduce_" #name, \
+                     col_reduce_2pass, \
+                     itype, otype, op, size_t, dim, bm, bn)

 #define instantiate_col_reduce_looped(name, itype, otype, op, dim) \
   instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32) \
   instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, 32, 32)

 #define instantiate_col_reduce_general(name, itype, otype, op) \
-  instantiate_col_reduce_small(name, itype, otype, op, 0) \
   instantiate_col_reduce_small(name, itype, otype, op, 1) \
   instantiate_col_reduce_small(name, itype, otype, op, 2) \
-  instantiate_col_reduce_small(name, itype, otype, op, 3) \
-  instantiate_col_reduce_small(name, itype, otype, op, 4) \
-  instantiate_col_reduce_looped(name, itype, otype, op, 0) \
+  instantiate_col_reduce_small(name, itype, otype, op, 5) \
   instantiate_col_reduce_looped(name, itype, otype, op, 1) \
   instantiate_col_reduce_looped(name, itype, otype, op, 2) \
-  instantiate_col_reduce_looped(name, itype, otype, op, 3) \
-  instantiate_col_reduce_looped(name, itype, otype, op, 4)
-
-#define instantiate_same_col_reduce_helper(name, tname, type, op) \
-  instantiate_col_reduce_general(name##tname, type, type, op<type>)
-
-instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types)
-instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_64b)
-
-instantiate_col_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
-instantiate_reduce_from_types(instantiate_col_reduce_general, and, bool, And<bool>)
-instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or<bool>)
+  instantiate_col_reduce_looped(name, itype, otype, op, 5)

 #define instantiate_row_reduce_small(name, itype, otype, op, dim) \
   instantiate_kernel("row_reduce_small_" #dim "_reduce_" #name, \
                      row_reduce_small, \
-                     itype, otype, op, dim)
+                     itype, otype, op, uint, dim) \
+  instantiate_kernel("row_reduce_small_large_" #dim "_reduce_" #name, \
+                     row_reduce_small, \
+                     itype, otype, op, size_t, dim)

 #define instantiate_row_reduce_looped(name, itype, otype, op, dim) \
   instantiate_kernel("row_reduce_looped_" #dim "_reduce_" #name, \
                      row_reduce_looped, \
-                     itype, otype, op, dim)
+                     itype, otype, op, uint, dim) \
+  instantiate_kernel("row_reduce_looped_large_" #dim "_reduce_" #name, \
+                     row_reduce_looped, \
+                     itype, otype, op, size_t, dim)

 #define instantiate_row_reduce_general(name, itype, otype, op) \
-  instantiate_row_reduce_small(name, itype, otype, op, 0) \
   instantiate_row_reduce_small(name, itype, otype, op, 1) \
   instantiate_row_reduce_small(name, itype, otype, op, 2) \
-  instantiate_row_reduce_small(name, itype, otype, op, 3) \
-  instantiate_row_reduce_small(name, itype, otype, op, 4) \
-  instantiate_row_reduce_looped(name, itype, otype, op, 0) \
+  instantiate_row_reduce_small(name, itype, otype, op, 5) \
   instantiate_row_reduce_looped(name, itype, otype, op, 1) \
   instantiate_row_reduce_looped(name, itype, otype, op, 2) \
-  instantiate_row_reduce_looped(name, itype, otype, op, 3) \
-  instantiate_row_reduce_looped(name, itype, otype, op, 4) \
+  instantiate_row_reduce_looped(name, itype, otype, op, 5) \
   instantiate_kernel("row_reduce_simple_" #name, \
                      row_reduce_simple, \
                      itype, otype, op)

-#define instantiate_same_row_reduce_helper(name, tname, type, op) \
-  instantiate_row_reduce_general(name##tname, type, type, op<type>)
-
-instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_types)
-instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_64b)
-
-instantiate_reduce_from_types(instantiate_row_reduce_general, and, bool, And<bool>)
-instantiate_reduce_from_types(instantiate_row_reduce_general, or, bool, Or<bool>)
-
-instantiate_row_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
+#define instantiate_reduce_functions(name, tname, itype, otype, op) \
+  instantiate_all_reduce(name##tname, itype, otype, op<otype>) \
+  instantiate_row_reduce_general(name##tname, itype, otype, op<otype>) \
+  instantiate_col_reduce_general(name##tname, itype, otype, op<otype>)
+
+#define instantiate_and_or(name, op) \
+  instantiate_reduce_functions(name, bool_, bool, bool, op) \
+  instantiate_reduce_functions(name, int16, int16_t, bool, op) \
+  instantiate_reduce_functions(name, int32, int32_t, bool, op) \
+  instantiate_reduce_functions(name, int64, int64_t, bool, op)
+
+instantiate_and_or(and, And)
+instantiate_and_or(or, Or)
+
+#define instantiate_sum_prod(name, op) \
+  instantiate_reduce_functions(name, int8, int8_t, int32_t, op) \
+  instantiate_reduce_functions(name, int16, int16_t, int32_t, op) \
+  instantiate_reduce_functions(name, int32, int32_t, int32_t, op) \
+  instantiate_reduce_functions(name, int64, int64_t, int64_t, op) \
+  instantiate_reduce_functions(name, float16, float16_t, float16_t, op) \
+  instantiate_reduce_functions(name, bfloat16, bfloat16_t, bfloat16_t, op) \
+  instantiate_reduce_functions(name, float32, float, float, op) \
+  instantiate_reduce_functions(name, complex64, complex64_t, complex64_t, op)
+
+instantiate_sum_prod(sum, Sum)
+instantiate_sum_prod(prod, Prod)
+
+#define instantiate_min_max(name, op) \
+  instantiate_reduce_functions(name, int8, int8_t, int8_t, op) \
+  instantiate_reduce_functions(name, int16, int16_t, int16_t, op) \
+  instantiate_reduce_functions(name, int32, int32_t, int32_t, op) \
+  instantiate_reduce_functions(name, int64, int64_t, int64_t, op) \
+  instantiate_reduce_functions(name, uint8, uint8_t, uint8_t, op) \
+  instantiate_reduce_functions(name, uint16, uint16_t, uint16_t, op) \
+  instantiate_reduce_functions(name, uint32, uint32_t, uint32_t, op) \
+  instantiate_reduce_functions(name, uint64, uint64_t, uint64_t, op) \
+  instantiate_reduce_functions(name, float16, float16_t, float16_t, op) \
+  instantiate_reduce_functions(name, bfloat16, bfloat16_t, bfloat16_t, op) \
+  instantiate_reduce_functions(name, float32, float, float, op) \
+  instantiate_reduce_functions(name, complex64, complex64_t, complex64_t, op)
+
+instantiate_min_max(min, Min)
+instantiate_min_max(max, Max)
 // clang-format on
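To see what the new macros expand to, here is a small sketch that prints the row-reduce kernel names generated for one (op, type) pair, including the `_large` variants that select `size_t` indexing. The names follow from hand-expanding `instantiate_row_reduce_general` above; the pair name is just an example.

```cpp
// Sketch: the row-reduce kernel names generated for one (op, type) pair by
// the macros above.
#include <cstdio>
#include <initializer_list>

int main() {
  const char* name = "sumint8"; // name##tname from instantiate_reduce_functions
  for (int dim : {1, 2, 5}) {
    std::printf("row_reduce_small_%d_reduce_%s\n", dim, name);
    std::printf("row_reduce_small_large_%d_reduce_%s\n", dim, name);
    std::printf("row_reduce_looped_%d_reduce_%s\n", dim, name);
    std::printf("row_reduce_looped_large_%d_reduce_%s\n", dim, name);
  }
  std::printf("row_reduce_simple_%s\n", name);
}
```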

View File

@@ -1,6 +1,11 @@
 // Copyright © 2023-2024 Apple Inc.

-template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
+template <
+    typename T,
+    typename U,
+    typename Op,
+    typename IdxT = int64_t,
+    int N_READS = REDUCE_N_READS>
 [[kernel]] void all_reduce(
     const device T* in [[buffer(0)]],
     device U* out [[buffer(1)]],
@@ -16,10 +21,10 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
   threadgroup U shared_vals[simd_size];
   U total = Op::init;

-  int64_t start_idx = gid.y * row_size;
-  int64_t actual_row =
+  IdxT start_idx = gid.y * IdxT(row_size);
+  IdxT actual_row =
       (start_idx + row_size <= in_size) ? row_size : in_size - start_idx;
-  int64_t blocks = actual_row / (lsize.x * N_READS);
+  IdxT blocks = actual_row / (lsize.x * N_READS);
   int extra = actual_row - blocks * (lsize.x * N_READS);
   extra -= lid.x * N_READS;
   start_idx += lid.x * N_READS;
@@ -30,7 +35,7 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
     extra = 0;
   }

-  for (int64_t b = 0; b < blocks; b++) {
+  for (IdxT b = 0; b < blocks; b++) {
     for (int i = 0; i < N_READS; i++) {
       total = op(static_cast<U>(in[i]), total);
     }

View File

@@ -1,6 +1,6 @@
 // Copyright © 2023-2024 Apple Inc.

-template <typename T, typename U, typename Op, int NDIMS>
+template <typename T, typename U, typename Op, typename IdxT, int NDIMS>
 [[kernel]] void col_reduce_small(
     const device T* in [[buffer(0)]],
     device U* out [[buffer(1)]],
@@ -19,7 +19,7 @@ template <typename T, typename U, typename Op, int NDIMS>
     uint3 lsize [[threads_per_threadgroup]]) {
   constexpr int n_reads = 4;
   Op op;
-  looped_elem_to_loc<NDIMS> loop;
+  LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
   const device T* row;

   U totals[n_reads];
@@ -27,20 +27,20 @@ template <typename T, typename U, typename Op, int NDIMS>
     totals[i] = Op::init;
   }

-  size_t column = size_t(gid.x) * lsize.x * n_reads + lid.x * n_reads;
+  IdxT column = IdxT(gid.x) * lsize.x * n_reads + lid.x * n_reads;
   if (column >= reduction_stride) {
     return;
   }
   bool safe = column + n_reads <= reduction_stride;

-  size_t out_idx = gid.y + gsize.y * size_t(gid.z);
-  size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
+  IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
+  IdxT in_idx = elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);
   in += in_idx + column;

-  size_t total_rows = non_col_reductions * reduction_size;
+  IdxT total_rows = IdxT(non_col_reductions) * IdxT(reduction_size);
   loop.next(lid.y, reduce_shape, reduce_strides);
-  for (size_t r = lid.y; r < total_rows; r += lsize.y) {
-    row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
+  for (IdxT r = lid.y; r < total_rows; r += lsize.y) {
+    row = in + loop.location();
     if (safe) {
       for (int i = 0; i < n_reads; i++) {
         totals[i] = op(static_cast<U>(row[i]), totals[i]);
@@ -80,7 +80,7 @@ template <typename T, typename U, typename Op, int NDIMS>
   }

   if (lid.y == 0) {
-    out += out_idx * reduction_stride + column;
+    out += out_idx * IdxT(reduction_stride) + column;
     if (safe) {
       for (int i = 0; i < n_reads; i++) {
         out[i] = totals[i];
@@ -93,7 +93,7 @@ template <typename T, typename U, typename Op, int NDIMS>
   }
 }

-template <typename T, typename U, typename Op, int NDIMS>
+template <typename T, typename U, typename Op, typename IdxT, int NDIMS>
 [[kernel]] void col_reduce_longcolumn(
     const device T* in [[buffer(0)]],
     device U* out [[buffer(1)]],
@@ -112,19 +112,19 @@ template <typename T, typename U, typename Op, int NDIMS>
     uint3 lid [[thread_position_in_threadgroup]],
     uint3 lsize [[threads_per_threadgroup]]) {
   Op op;
-  looped_elem_to_loc<NDIMS> loop;
+  LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
   const device T* row;

-  size_t out_idx = gid.x + gsize.x * size_t(gid.y);
-  size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
+  IdxT out_idx = gid.x + gsize.x * IdxT(gid.y);
+  IdxT in_idx = elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);
   in += in_idx + lid.x;

   U total = Op::init;
-  size_t total_rows = non_col_reductions * reduction_size;
+  IdxT total_rows = IdxT(non_col_reductions) * IdxT(reduction_size);
   loop.next(gid.z * lsize.y + lid.y, reduce_shape, reduce_strides);
-  for (size_t r = gid.z * lsize.y + lid.y; r < total_rows;
+  for (IdxT r = gid.z * lsize.y + lid.y; r < total_rows;
        r += lsize.y * gsize.z) {
-    row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
+    row = in + loop.location();
     total = op(static_cast<U>(*row), total);
     loop.next(lsize.y * gsize.z, reduce_shape, reduce_strides);
   }
@@ -136,7 +136,8 @@ template <typename T, typename U, typename Op, int NDIMS>
     for (uint i = 1; i < lsize.y; i++) {
       total = op(total, shared_vals[i * lsize.x + lid.x]);
     }
-    out[gid.z * out_size + out_idx * reduction_stride + lid.x] = total;
+    out[gid.z * IdxT(out_size) + out_idx * IdxT(reduction_stride) + lid.x] =
+        total;
   }
 }
@@ -151,7 +152,14 @@ template <typename T, typename U, typename Op, int NDIMS>
  * totals with a loop.
  * 7. Write them to the output
  */
-template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
+template <
+    typename T,
+    typename U,
+    typename Op,
+    typename IdxT,
+    int NDIMS,
+    int BM,
+    int BN>
 [[kernel]] void col_reduce_looped(
     const device T* in [[buffer(0)]],
     device U* out [[buffer(1)]],
@@ -176,7 +184,7 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
   threadgroup U shared_vals[BN * BM];
   U totals[n_reads];
-  looped_elem_to_loc<NDIMS> loop;
+  LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
   const device T* row;

   for (int i = 0; i < n_reads; i++) {
@@ -185,17 +193,17 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
   short lid = simd_group_id * simd_size + simd_lane_id;
   short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);
-  size_t column = BN * gid.x + offset.x;
+  IdxT column = BN * gid.x + offset.x;
   bool safe = column + n_reads <= reduction_stride;

-  size_t out_idx = gid.y + gsize.y * size_t(gid.z);
-  size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
+  IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
+  IdxT in_idx = elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);
   in += in_idx + column;

-  size_t total = non_col_reductions * reduction_size;
+  IdxT total = IdxT(non_col_reductions) * IdxT(reduction_size);
   loop.next(offset.y, reduce_shape, reduce_strides);
-  for (size_t r = offset.y; r < total; r += BM) {
-    row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
+  for (IdxT r = offset.y; r < total; r += BM) {
+    row = in + loop.location();

     if (safe) {
       for (int i = 0; i < n_reads; i++) {
@@ -235,8 +243,8 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
     // Write the output.
     if (simd_lane_id == 0) {
-      size_t out_column = BN * gid.x + out_offset.x;
-      out += out_idx * reduction_stride + out_column;
+      IdxT out_column = BN * gid.x + out_offset.x;
+      out += out_idx * IdxT(reduction_stride) + out_column;
       if (out_column + n_outputs <= reduction_stride) {
         for (int i = 0; i < n_outputs; i++) {
           out[i] = totals[i];
@@ -269,7 +277,7 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
     // Write the output.
     if (offset.y == 0) {
-      out += out_idx * reduction_stride + column;
+      out += out_idx * IdxT(reduction_stride) + column;
       if (safe) {
         for (int i = 0; i < n_reads; i++) {
           out[i] = totals[i];
@@ -283,7 +291,14 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
   }
 }

-template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
+template <
+    typename T,
+    typename U,
+    typename Op,
+    typename IdxT,
+    int NDIMS,
+    int BM,
+    int BN>
 [[kernel]] void col_reduce_2pass(
     const device T* in [[buffer(0)]],
     device U* out [[buffer(1)]],
@@ -312,7 +327,7 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
   threadgroup U shared_vals[BN * BM];
   U totals[n_reads];
-  looped_elem_to_loc<NDIMS> loop;
+  LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
   const device T* row;

   for (int i = 0; i < n_reads; i++) {
@@ -321,20 +336,19 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
   short lid = simd_group_id * simd_size + simd_lane_id;
   short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);
-  size_t column = BN * gid.x + offset.x;
+  IdxT column = BN * gid.x + offset.x;
   bool safe = column + n_reads <= reduction_stride;

-  size_t full_idx = gid.y + gsize.y * size_t(gid.z);
-  size_t block_idx = full_idx / out_size;
-  size_t out_idx = full_idx % out_size;
-  size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
+  IdxT full_idx = gid.y + gsize.y * IdxT(gid.z);
+  IdxT block_idx = full_idx / IdxT(out_size);
+  IdxT out_idx = full_idx % IdxT(out_size);
+  IdxT in_idx = elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);
   in += in_idx + column;

-  size_t total = non_col_reductions * reduction_size;
+  IdxT total = IdxT(non_col_reductions) * IdxT(reduction_size);
   loop.next(offset.y + block_idx * BM, reduce_shape, reduce_strides);
-  for (size_t r = offset.y + block_idx * BM; r < total;
-       r += outer_blocks * BM) {
-    row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
+  for (IdxT r = offset.y + block_idx * BM; r < total; r += outer_blocks * BM) {
+    row = in + loop.location();

     if (safe) {
       for (int i = 0; i < n_reads; i++) {
@@ -369,8 +383,8 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
   // Write the output.
   if (simd_lane_id == 0) {
-    size_t out_column = BN * gid.x + out_offset.x;
-    out += full_idx * reduction_stride + out_column;
+    IdxT out_column = BN * gid.x + out_offset.x;
+    out += full_idx * IdxT(reduction_stride) + out_column;
     if (out_column + n_outputs <= reduction_stride) {
       for (int i = 0; i < n_outputs; i++) {
         out[i] = totals[i];
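A plain C++ analogue of the 2-pass strategy in `col_reduce_2pass`: each of `outer_blocks` partial outputs reduces a strided subset of the rows, and the partials are folded afterwards; the block and output indices are recovered as `full_idx / out_size` and `full_idx % out_size`, as in the kernel. This is a host-side sketch of the idea, not the Metal code.

```cpp
// Two-pass column reduction sketch. Assumes a contiguous (rows, cols) input
// reduced over rows; cols plays the role of out_size.
#include <cassert>
#include <cstddef>
#include <vector>

std::vector<float> two_pass_col_sum(
    const std::vector<float>& in, size_t rows, size_t cols, size_t outer_blocks) {
  std::vector<float> partials(outer_blocks * cols, 0.0f);
  // Pass 1: block b handles rows b, b + outer_blocks, b + 2 * outer_blocks, ...
  for (size_t full_idx = 0; full_idx < outer_blocks * cols; full_idx++) {
    size_t block_idx = full_idx / cols; // as in col_reduce_2pass
    size_t col = full_idx % cols;
    for (size_t r = block_idx; r < rows; r += outer_blocks) {
      partials[full_idx] += in[r * cols + col];
    }
  }
  // Pass 2: fold the outer_blocks partials for each column.
  std::vector<float> out(cols, 0.0f);
  for (size_t b = 0; b < outer_blocks; b++) {
    for (size_t col = 0; col < cols; col++) {
      out[col] += partials[b * cols + col];
    }
  }
  return out;
}

int main() {
  std::vector<float> in(8 * 3, 1.0f); // 8 rows, 3 cols
  auto out = two_pass_col_sum(in, 8, 3, 4);
  for (float c : out) {
    assert(c == 8.0f);
  }
  return 0;
}
```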

View File

@@ -193,6 +193,7 @@ template <
     typename T,
     typename U,
     typename Op,
+    typename IdxT,
     int NDIMS,
     int N_READS = REDUCE_N_READS>
 [[kernel]] void row_reduce_small(
@@ -214,20 +215,20 @@ template <
   Op op;

   U total_val = Op::init;
-  looped_elem_to_loc<NDIMS> loop;
+  LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);

   // Precompute some row reduction numbers
   const device T* row;
-  int blocks = row_size / N_READS;
-  int extra = row_size % N_READS;
+  int blocks = IdxT(row_size) / N_READS;
+  int extra = IdxT(row_size) % N_READS;

   if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) {
     // Simple loop over non_row_reductions and reduce the row in the thread.
-    size_t out_idx = tid.x + tsize.y * size_t(tid.y);
-    in += elem_to_loc(out_idx, shape, strides, ndim);
+    IdxT out_idx = tid.x + tsize.y * IdxT(tid.y);
+    in += elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);

     for (uint r = 0; r < non_row_reductions; r++) {
-      row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
+      row = in + loop.location();
       thread_reduce<T, U, Op, N_READS>(total_val, row, blocks, extra);
       loop.next(reduce_shape, reduce_strides);
     }
@@ -236,13 +237,13 @@ template <
   } else {
     // Collaboratively reduce over non_row_reductions in the simdgroup. Each
     // thread reduces every 32nd row and then a simple simd reduce.
-    size_t out_idx = gid.y + gsize.y * size_t(gid.z);
-    in += elem_to_loc(out_idx, shape, strides, ndim);
+    IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
+    in += elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);

     loop.next(simd_lane_id, reduce_shape, reduce_strides);

     for (uint r = simd_lane_id; r < non_row_reductions; r += simd_size) {
-      row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
+      row = in + loop.location();
       thread_reduce<T, U, Op, N_READS>(total_val, row, blocks, extra);
       loop.next(simd_size, reduce_shape, reduce_strides);
     }
@@ -259,6 +260,7 @@ template <
     typename T,
     typename U,
     typename Op,
+    typename IdxT = size_t,
     int N_READS = REDUCE_N_READS,
     int N_WRITES = REDUCE_N_WRITES>
 [[kernel]] void row_reduce_simple(
@@ -277,15 +279,15 @@ template <
   U totals[N_WRITES];

   // Move to the row
-  size_t out_idx = N_WRITES * (gid.y + gsize.y * size_t(gid.z));
+  IdxT out_idx = N_WRITES * (gid.y + gsize.y * IdxT(gid.z));
   if (out_idx + N_WRITES > out_size) {
     out_idx = out_size - N_WRITES;
   }
-  in += out_idx * reduction_size;
+  in += out_idx * IdxT(reduction_size);
   out += out_idx;

   // Each thread reduces across the row
-  int blocks = reduction_size / (lsize.x * N_READS);
+  int blocks = IdxT(reduction_size) / (lsize.x * N_READS);
   int extra = reduction_size - blocks * (lsize.x * N_READS);
   per_thread_row_reduce<T, U, Op, N_READS, N_WRITES>(
       totals, in, reduction_size, blocks, extra, lsize.x, lid.x);
@@ -306,6 +308,7 @@ template <
     typename T,
     typename U,
     typename Op,
+    typename IdxT,
     int NDIMS,
     int N_READS = REDUCE_N_READS>
 [[kernel]] void row_reduce_looped(
@@ -330,19 +333,20 @@ template <
   threadgroup U shared_vals[simd_size];
   U total = Op::init;

-  size_t out_idx = gid.y + gsize.y * size_t(gid.z);
+  IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);

   // lid.x * N_READS breaks the per_thread_row_reduce interface a bit. Maybe it
   // needs a small refactor.
-  in += elem_to_loc(out_idx, shape, strides, ndim) + lid.x * N_READS;
+  in += elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim) +
+      lid.x * N_READS;

-  looped_elem_to_loc<NDIMS> loop;
+  LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
   const device T* row;
-  int blocks = row_size / (lsize.x * N_READS);
+  int blocks = IdxT(row_size) / (lsize.x * N_READS);
   int extra = row_size - blocks * (lsize.x * N_READS);

-  for (size_t i = 0; i < non_row_reductions; i++) {
-    row = in + loop.location(i, reduce_shape, reduce_strides, reduce_ndim);
+  for (IdxT i = 0; i < non_row_reductions; i++) {
+    row = in + loop.location();

     // Each thread reduces across the row
     U row_total;

View File

@@ -204,16 +204,21 @@ METAL_FUNC vec<IdxT, 3> elem_to_loc_3_nd(
 ///////////////////////////////////////////////////////////////////////////////
 // Elem to loc in a loop utils
 ///////////////////////////////////////////////////////////////////////////////

-template <int dim, typename offset_t = size_t>
-struct looped_elem_to_loc {
-  looped_elem_to_loc<dim - 1, offset_t> inner_looper;
-  offset_t offset{0};
+template <int DIM, typename OffsetT = size_t, bool General = true>
+struct LoopedElemToLoc {
+  int dim;
+  LoopedElemToLoc<DIM - 1, OffsetT, General> inner_looper;
+  OffsetT offset{0};
   int index{0};

-  void next(const constant int* shape, const constant size_t* strides) {
-    index++;
-    offset += strides[dim - 1];
+  LoopedElemToLoc(int dim) : dim(dim), inner_looper(dim - 1) {}

+  void next(const constant int* shape, const constant size_t* strides) {
+    if (dim == 0) {
+      return;
+    }
+    index++;
+    offset += OffsetT(strides[dim - 1]);
     if (index >= shape[dim - 1]) {
       index = 0;
       inner_looper.next(shape, strides);
@@ -222,13 +227,21 @@ struct looped_elem_to_loc {
   }

   void next(int n, const constant int* shape, const constant size_t* strides) {
+    if (dim == 0) {
+      return;
+    }
     index += n;
-    offset += n * strides[dim - 1];
+    offset += n * OffsetT(strides[dim - 1]);
     if (index >= shape[dim - 1]) {
       int extra = index - shape[dim - 1];
+      if (extra >= shape[dim - 1]) {
+        inner_looper.next(1 + extra / shape[dim - 1], shape, strides);
+        extra = extra % shape[dim - 1];
+      } else {
+        inner_looper.next(shape, strides);
+      }
       index = 0;
-      inner_looper.next(shape, strides);
       offset = inner_looper.offset;
       if (extra > 0) {
         next(extra, shape, strides);
@@ -236,44 +249,61 @@ struct looped_elem_to_loc {
     }
   }

-  offset_t
-  location(offset_t, const constant int*, const constant size_t*, int) {
+  OffsetT location() {
     return offset;
   }
 };

-template <typename offset_t>
-struct looped_elem_to_loc<1, offset_t> {
-  offset_t offset{0};
+template <typename OffsetT>
+struct LoopedElemToLoc<1, OffsetT, true> {
+  int dim;
+  OffsetT offset{0};
+  uint index{0};

-  void next(const constant int*, const constant size_t* strides) {
-    offset += strides[0];
+  LoopedElemToLoc(int dim) : dim(dim) {}
+
+  void next(const constant int* shape, const constant size_t* strides) {
+    index++;
+    if (dim > 1) {
+      offset = elem_to_loc<size_t, OffsetT>(index, shape, strides, dim);
+    } else {
+      offset += OffsetT(strides[0]);
+    }
   }

-  void next(int n, const constant int*, const constant size_t* strides) {
-    offset += n * strides[0];
+  void next(int n, const constant int* shape, const constant size_t* strides) {
+    index += n;
+    if (dim > 1) {
+      offset = elem_to_loc<size_t, OffsetT>(index, shape, strides, dim);
+    } else {
+      offset = index * OffsetT(strides[0]);
+    }
   }

-  offset_t
-  location(offset_t, const constant int*, const constant size_t*, int) {
+  OffsetT location() {
     return offset;
   }
 };

-template <typename offset_t>
-struct looped_elem_to_loc<0, offset_t> {
-  void next(const constant int*, const constant size_t*) {}
-  void next(int, const constant int*, const constant size_t*) {}
-
-  offset_t location(
-      offset_t idx,
-      const constant int* shape,
-      const constant size_t* strides,
-      int ndim) {
-    return elem_to_loc(idx, shape, strides, ndim);
-  }
-};
+template <typename OffsetT>
+struct LoopedElemToLoc<1, OffsetT, false> {
+  OffsetT offset{0};
+
+  LoopedElemToLoc(int) {}
+
+  void next(const constant int*, const constant size_t* strides) {
+    offset += OffsetT(strides[0]);
+  }
+
+  void next(int n, const constant int*, const constant size_t* strides) {
+    offset += n * OffsetT(strides[0]);
+  }
+
+  OffsetT location() {
+    return offset;
+  }
+};

 ///////////////////////////////////////////////////////////////////////////////
 // Calculation utils
 ///////////////////////////////////////////////////////////////////////////////
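A host-side analogue of the incremental walk `LoopedElemToLoc` performs, in plain C++ (simplified, not the Metal code): per-dimension counters advance the strided offset without recomputing a division per element, and the result must agree with a direct `elem_to_loc` computation.

```cpp
// Incremental multi-dimensional index walk, checked against the direct
// linear-index-to-offset computation.
#include <cassert>
#include <cstddef>
#include <vector>

// Direct computation: linear index -> strided offset (row-major).
size_t elem_to_loc(size_t idx, const std::vector<int>& shape,
                   const std::vector<size_t>& strides) {
  size_t loc = 0;
  for (int i = (int)shape.size() - 1; i >= 0; i--) {
    loc += (idx % shape[i]) * strides[i];
    idx /= shape[i];
  }
  return loc;
}

// Incremental analogue: advance one element at a time, carrying into outer
// dimensions, so no division is needed per step.
struct Looper {
  std::vector<int> index;
  size_t offset = 0;
  explicit Looper(size_t ndim) : index(ndim, 0) {}

  void next(const std::vector<int>& shape, const std::vector<size_t>& strides) {
    for (int d = (int)shape.size() - 1; d >= 0; d--) {
      index[d]++;
      offset += strides[d];
      if (index[d] < shape[d]) {
        return;
      }
      offset -= size_t(index[d]) * strides[d]; // undo a full pass of dim d
      index[d] = 0;                            // and carry into dim d - 1
    }
  }
};

int main() {
  std::vector<int> shape{3, 4, 5};
  std::vector<size_t> strides{40, 10, 1}; // deliberately non-contiguous
  Looper loop(shape.size());
  for (size_t r = 0; r < 3 * 4 * 5; r++) {
    assert(loop.offset == elem_to_loc(r, shape, strides));
    loop.next(shape, strides);
  }
  return 0;
}
```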

View File

@@ -1,3 +1,4 @@
+// Copyright © 2024 Apple Inc.

 #include "mlx/backend/metal/kernels.h"
 #include "mlx/backend/metal/utils.h"
@@ -99,7 +100,7 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
     const std::string& kernel_name,
     const std::string&,
     const std::string&,
-    const array&) {
+    const Dtype&) {
   return d.get_kernel(kernel_name);
 }
@@ -108,8 +109,9 @@ MTL::ComputePipelineState* get_reduce_kernel(
     const std::string& kernel_name,
     const std::string&,
     const std::string&,
-    const array&,
-    const array&,
+    const Dtype&,
+    const Dtype&,
+    const std::string&,
     int,
     int,
     int) {

View File

@@ -2,7 +2,6 @@
 #include <algorithm>
 #include <cassert>
-#include <sstream>

 #include "mlx/backend/metal/copy.h"
 #include "mlx/backend/metal/device.h"
@@ -202,6 +201,16 @@ inline bool is_64b_dtype(Dtype dtype) {
   return dtype == int64 || dtype == uint64 || dtype == complex64;
 }

+inline int get_kernel_reduce_ndim(int reduce_ndim) {
+  if (reduce_ndim <= 1) {
+    return 1;
+  } else if (reduce_ndim == 2) {
+    return 2;
+  } else {
+    return 5;
+  }
+}
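`get_kernel_reduce_ndim` is the "change dim specializations" bullet from the commit message: only three NDIMS buckets are compiled. A sketch of how a dispatcher uses it; per the kernel changes above, buckets 1 and 2 are exact fast paths while 5 is the general template that receives the true `reduce_ndim` at runtime through `LoopedElemToLoc`.

```cpp
// Sketch: the three compiled NDIMS buckets. Reductions over 3, 4, or 5 dims
// all share the NDIMS=5 instantiation; the loop helper takes the actual
// ndim at runtime.
#include <cstdio>

inline int get_kernel_reduce_ndim(int reduce_ndim) {
  if (reduce_ndim <= 1) {
    return 1;
  } else if (reduce_ndim == 2) {
    return 2;
  } else {
    return 5;
  }
}

int main() {
  for (int nd = 1; nd <= 5; nd++) {
    std::printf("reduce_ndim=%d -> NDIMS=%d kernel\n", nd, get_kernel_reduce_ndim(nd));
  }
}
```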
 inline int threadgroup_size_from_row_size(int row_size) {
   // 1 simdgroup per row smallish rows
   if (row_size <= 512) {
@@ -233,16 +242,51 @@ inline auto output_grid_for_col_reduce(
   return get_2d_grid_dims(out_shape, out_strides);
 }

+std::pair<Dtype, Dtype> remap_reduce_types(
+    const array& in,
+    const std::string& op_name) {
+  if (op_name == "sum" || op_name == "prod") {
+    if (issubdtype(in.dtype(), integer)) {
+      switch (in.dtype().size()) {
+        case 1:
+          return {int8, int32};
+        case 2:
+          return {int16, int32};
+        case 4:
+          return {int32, int32};
+        case 8:
+          return {int64, int64};
+      }
+    }
+    if (in.dtype() == bool_) {
+      return {int8, int32};
+    }
+    return {in.dtype(), in.dtype()};
+  } else if (op_name == "and" || op_name == "or") {
+    if (in.dtype().size() == 1) {
+      return {bool_, bool_};
+    } else if (in.dtype().size() == 2) {
+      return {int16, bool_};
+    } else if (in.dtype().size() == 4) {
+      return {int32, bool_};
+    } else {
+      return {int64, bool_};
+    }
+  }
+  return {in.dtype(), in.dtype()};
+}
+
 void init_reduce(
     array& out,
     const std::string& op_name,
     CommandEncoder& compute_encoder,
     metal::Device& d,
     const Stream& s) {
-  std::ostringstream kname;
+  auto [_, out_type] = remap_reduce_types(out, op_name);
   const std::string func_name = "init_reduce";
-  kname << func_name << "_" << op_name << type_to_name(out);
-  auto kernel = get_reduce_init_kernel(d, kname.str(), func_name, op_name, out);
+  std::string kname = func_name;
+  concatenate(kname, "_", op_name, type_to_name(out_type));
+  auto kernel = get_reduce_init_kernel(d, kname, func_name, op_name, out_type);
   size_t nthreads = out.size();
   MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
   NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
@@ -263,10 +307,12 @@ void all_reduce_dispatch(
     metal::Device& d,
     const Stream& s) {
   // Set the kernel
-  std::ostringstream kname;
+  auto [in_type, out_type] = remap_reduce_types(in, op_name);
   const std::string func_name = "all_reduce";
-  kname << func_name << "_" << op_name << type_to_name(in);
-  auto kernel = get_reduce_kernel(d, kname.str(), func_name, op_name, in, out);
+  std::string kname = func_name;
+  concatenate(kname, "_", op_name, type_to_name(in_type));
+  auto kernel = get_reduce_kernel(
+      d, kname, func_name, op_name, in_type, out_type, "int64_t");
   compute_encoder.set_compute_pipeline_state(kernel);

   size_t in_size = in.size();
@@ -300,7 +346,7 @@ void all_reduce_dispatch(
   }

   // Allocate an intermediate tensor to hold results if needed
-  array intermediate({n_rows}, out.dtype(), nullptr, {});
+  array intermediate({n_rows}, out_type, nullptr, {});
   intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
   d.add_temporary(intermediate, s.index);
@@ -318,10 +364,10 @@ void all_reduce_dispatch(
   compute_encoder.dispatch_threads(grid_dims, group_dims);

   // 2nd pass
-  std::ostringstream kname_2nd_pass;
-  kname_2nd_pass << func_name << "_" << op_name << type_to_name(intermediate);
+  std::string kname_2nd_pass = func_name;
+  concatenate(kname_2nd_pass, "_", op_name, type_to_name(intermediate));
   auto kernel_2nd_pass = get_reduce_kernel(
-      d, kname_2nd_pass.str(), func_name, op_name, intermediate, out);
+      d, kname_2nd_pass, func_name, op_name, out_type, out_type, "int64_t");
   compute_encoder.set_compute_pipeline_state(kernel_2nd_pass);
   size_t intermediate_size = n_rows;
   grid_dims = MTL::Size(threadgroup_2nd_pass, 1, 1);
@@ -343,12 +389,30 @@ void row_reduce_small(
     metal::Device& d,
     const Stream& s) {
   // Set the kernel
-  std::ostringstream kname;
-  int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
+  int n = get_kernel_reduce_ndim(args.reduce_ndim);
+  auto [in_type, out_type] = remap_reduce_types(in, op_name);
   const std::string func_name = "row_reduce_small";
-  kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in);
-  auto kernel =
-      get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n);
+  std::string kname = func_name;
+  bool large = in.size() > UINT32_MAX;
+  if (large) {
+    kname += "_large";
+  }
+  concatenate(
+      kname,
+      "_",
+      std::to_string(n),
+      "_reduce_",
+      op_name,
+      type_to_name(in_type));
+  auto kernel = get_reduce_kernel(
+      d,
+      kname,
+      func_name,
+      op_name,
+      in_type,
+      out_type,
+      large ? "size_t" : "uint",
+      n);
   compute_encoder.set_compute_pipeline_state(kernel);

   // Figure out the grid dims
@@ -381,10 +445,13 @@ void row_reduce_simple(
     metal::Device& d,
     const Stream& s) {
   // Set the kernel
-  std::ostringstream kname;
+  auto [in_type, out_type] = remap_reduce_types(in, op_name);
   const std::string func_name = "row_reduce_simple";
-  kname << func_name << "_" << op_name << type_to_name(in);
-  auto kernel = get_reduce_kernel(d, kname.str(), func_name, op_name, in, out);
+  std::string kname = func_name;
+  concatenate(kname, "_", op_name, type_to_name(in_type));
+  auto kernel = get_reduce_kernel(
+      d, kname, func_name, op_name, in_type, out_type, "size_t");
   compute_encoder.set_compute_pipeline_state(kernel);

   // Figure out the grid dims
@@ -417,13 +484,32 @@ void row_reduce_looped(
     CommandEncoder& compute_encoder,
     metal::Device& d,
     const Stream& s) {
+  auto [in_type, out_type] = remap_reduce_types(in, op_name);
+
   // Set the kernel
-  std::ostringstream kname;
-  int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
+  int n = get_kernel_reduce_ndim(args.reduce_ndim);
   const std::string func_name = "row_reduce_looped";
-  kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in);
-  auto kernel =
-      get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n);
+  std::string kname = func_name;
+  bool large = in.size() > UINT32_MAX;
+  if (large) {
+    kname += "_large";
+  }
+  concatenate(
+      kname,
+      "_",
+      std::to_string(n),
+      "_reduce_",
+      op_name,
+      type_to_name(in_type));
+  auto kernel = get_reduce_kernel(
+      d,
+      kname,
+      func_name,
+      op_name,
+      in_type,
+      out_type,
+      large ? "size_t" : "uint",
+      n);
   compute_encoder.set_compute_pipeline_state(kernel);

   // Figure out the grid
@@ -475,6 +561,8 @@ void strided_reduce_small(
     CommandEncoder& compute_encoder,
     metal::Device& d,
     const Stream& s) {
+  auto [in_type, out_type] = remap_reduce_types(in, op_name);
   // Figure out the grid dims
   MTL::Size grid_dims, group_dims;
@@ -483,12 +571,29 @@ void strided_reduce_small(
   args.reduce_strides.push_back(args.reduction_stride);
   args.reduce_ndim++;
-  int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
-  std::ostringstream kname;
+  int n = get_kernel_reduce_ndim(args.reduce_ndim);
   const std::string func_name = "col_reduce_small";
-  kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in);
-  auto kernel =
-      get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n);
+  std::string kname = func_name;
+  bool large = in.size() > UINT32_MAX;
+  if (large) {
+    kname += "_large";
+  }
+  concatenate(
+      kname,
+      "_",
+      std::to_string(n),
+      "_reduce_",
+      op_name,
+      type_to_name(in_type));
+  auto kernel = get_reduce_kernel(
+      d,
+      kname,
+      func_name,
+      op_name,
+      in_type,
+      out_type,
+      large ? "size_t" : "uint",
+      n);
   compute_encoder.set_compute_pipeline_state(kernel);
   const int n_reads = 4;
@@ -522,6 +627,7 @@ void strided_reduce_longcolumn(
     CommandEncoder& compute_encoder,
     metal::Device& d,
     const Stream& s) {
+  auto [in_type, out_type] = remap_reduce_types(in, op_name);
   size_t total_reduction_size = args.reduction_size * args.non_col_reductions;
   size_t outer_blocks = 32;
   if (total_reduction_size >= 32768) {
@@ -534,7 +640,7 @@ void strided_reduce_longcolumn(
   intermediate_shape.push_back(outer_blocks);
   intermediate_shape.insert(
       intermediate_shape.end(), out.shape().begin(), out.shape().end());
-  array intermediate(std::move(intermediate_shape), out.dtype(), nullptr, {});
+  array intermediate(std::move(intermediate_shape), out_type, nullptr, {});
   intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
   d.add_temporary(intermediate, s.index);
@@ -556,12 +662,29 @@ void strided_reduce_longcolumn(
   MTL::Size group_dims(threadgroup_x, threadgroup_y, 1);
   // Set the kernel
-  int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
-  std::ostringstream kname;
-  const std::string func_name = "col_reduce_longcolumn";
-  kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in);
-  auto kernel =
-      get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n);
+  int n = get_kernel_reduce_ndim(args.reduce_ndim);
+  std::string func_name = "col_reduce_longcolumn";
+  std::string kname = func_name;
+  bool large = in.size() > UINT32_MAX;
+  if (large) {
+    kname += "_large";
+  }
+  concatenate(
+      kname,
+      "_",
+      std::to_string(n),
+      "_reduce_",
+      op_name,
+      type_to_name(in_type));
+  auto kernel = get_reduce_kernel(
+      d,
+      kname,
+      func_name,
+      op_name,
+      in_type,
+      out_type,
+      large ? "size_t" : "uint",
+      n);
   compute_encoder.set_compute_pipeline_state(kernel);
   // Launch
@@ -581,15 +704,21 @@ void strided_reduce_longcolumn(
   group_dims = MTL::Size(256, 1, 1);
   // Set the 2nd kernel
-  const std::string second_kernel = "col_reduce_looped_1_32_32_reduce_" +
-      op_name + type_to_name(intermediate);
+  func_name = "col_reduce_looped";
+  kname = func_name;
+  large = intermediate.size() > UINT32_MAX;
+  if (large) {
+    kname += "_large";
+  }
+  concatenate(kname, "_1_32_32_reduce_", op_name, type_to_name(intermediate));
   kernel = get_reduce_kernel(
       d,
-      second_kernel,
-      "col_reduce_looped",
+      kname,
+      func_name,
       op_name,
-      intermediate,
-      out,
+      intermediate.dtype(),
+      out_type,
+      large ? "size_t" : "uint",
       1,
       32,
       32);
@@ -609,6 +738,8 @@ void strided_reduce_looped(
     CommandEncoder& compute_encoder,
     metal::Device& d,
     const Stream& s) {
+  auto [in_type, out_type] = remap_reduce_types(in, op_name);
   // Prepare the arguments for the kernel
   args.reduce_shape.push_back(args.reduction_size);
   args.reduce_strides.push_back(args.reduction_stride);
@@ -626,13 +757,35 @@ void strided_reduce_looped(
   MTL::Size group_dims(threadgroup_size, 1, 1);
   // Set the kernel
-  int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
-  std::ostringstream kname;
-  const std::string func_name = "col_reduce_looped";
-  kname << func_name << "_" << n << "_" << BM << "_" << BN << "_reduce_"
-        << op_name << type_to_name(in);
-  auto kernel =
-      get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n, BM, BN);
+  int n = get_kernel_reduce_ndim(args.reduce_ndim);
+  std::string func_name = "col_reduce_looped";
+  std::string kname = func_name;
+  bool large = in.size() > UINT32_MAX;
+  if (large) {
+    kname += "_large";
+  }
+  concatenate(
+      kname,
+      "_",
+      std::to_string(n),
+      "_",
+      std::to_string(BM),
+      "_",
+      std::to_string(BN),
+      "_reduce_",
+      op_name,
+      type_to_name(in_type));
+  auto kernel = get_reduce_kernel(
+      d,
+      kname,
+      func_name,
+      op_name,
+      in_type,
+      out_type,
+      large ? "size_t" : "uint",
+      n,
+      BM,
+      BN);
   compute_encoder.set_compute_pipeline_state(kernel);
   // Launch
@@ -650,13 +803,15 @@ void strided_reduce_2pass(
     CommandEncoder& compute_encoder,
     metal::Device& d,
     const Stream& s) {
+  auto [in_type, out_type] = remap_reduce_types(in, op_name);
   // Prepare the temporary accumulator
   std::vector<int> intermediate_shape;
   intermediate_shape.reserve(out.ndim() + 1);
   intermediate_shape.push_back(32);
   intermediate_shape.insert(
       intermediate_shape.end(), out.shape().begin(), out.shape().end());
-  array intermediate(std::move(intermediate_shape), out.dtype(), nullptr, {});
+  array intermediate(std::move(intermediate_shape), out_type, nullptr, {});
   intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
   d.add_temporary(intermediate, s.index);
@@ -679,13 +834,35 @@ void strided_reduce_2pass(
   MTL::Size group_dims(threadgroup_size, 1, 1);
   // Set the kernel
-  int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
-  std::ostringstream kname;
-  const std::string func_name = "col_reduce_2pass";
-  kname << func_name << "_" << n << "_" << BM << "_" << BN << "_reduce_"
-        << op_name << type_to_name(in);
-  auto kernel =
-      get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n, BM, BN);
+  int n = get_kernel_reduce_ndim(args.reduce_ndim);
+  std::string func_name = "col_reduce_2pass";
+  std::string kname = func_name;
+  bool large = in.size() > UINT32_MAX;
+  if (large) {
+    kname += "_large";
+  }
+  concatenate(
+      kname,
+      "_",
+      std::to_string(n),
+      "_",
+      std::to_string(BM),
+      "_",
+      std::to_string(BN),
+      "_reduce_",
+      op_name,
+      type_to_name(in_type));
+  auto kernel = get_reduce_kernel(
+      d,
+      kname,
+      func_name,
+      op_name,
+      in_type,
+      out_type,
+      large ? "size_t" : "uint",
+      n,
+      BM,
+      BN);
   compute_encoder.set_compute_pipeline_state(kernel);
   // Launch
@@ -703,15 +880,21 @@ void strided_reduce_2pass(
   grid_dims = MTL::Size(threadgroup_size * ((out.size() + BN - 1) / BN), 1, 1);
   // Set the 2nd kernel
-  const std::string second_kernel = "col_reduce_looped_1_32_32_reduce_" +
-      op_name + type_to_name(intermediate);
+  func_name = "col_reduce_looped";
+  kname = func_name;
+  large = intermediate.size() > UINT32_MAX;
+  if (large) {
+    kname += "_large";
+  }
+  concatenate(kname, "_1_32_32_reduce_", op_name, type_to_name(intermediate));
   kernel = get_reduce_kernel(
       d,
-      second_kernel,
-      "col_reduce_looped",
+      kname,
+      func_name,
       op_name,
-      intermediate,
-      out,
+      intermediate.dtype(),
+      out_type,
+      large ? "size_t" : "uint",
       1,
       32,
       32);
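Note that `large` is recomputed against `intermediate` before the second pass: the intermediate accumulator holds 32 partial rows per output element, so its size, and hence whether the kernel needs 64-bit indexing, can differ from that of the original input.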
@@ -780,7 +963,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
       op_name = "sum";
       break;
     case Reduce::Prod:
-      op_name = out.dtype() == bool_ ? "and" : "prod";
+      op_name = "prod";
       break;
     case Reduce::Min:
       op_name = out.dtype() == bool_ ? "and" : "min";
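The bool-to-"and" remap for Prod can be dropped here because sum and prod now promote their output dtype in ops.cpp (see below), so `out.dtype()` is never bool_ for a product. Min and Max keep the remap since they still return bool for bool inputs.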


@@ -6,9 +6,9 @@ using namespace mlx;
 namespace mlx::core {
-std::string type_to_name(const array& a) {
+std::string type_to_name(const Dtype& t) {
   std::string tname;
-  switch (a.dtype()) {
+  switch (t) {
     case bool_:
       tname = "bool_";
       break;
@@ -52,6 +52,10 @@ std::string type_to_name(const array& a) {
   return tname;
 }
+std::string type_to_name(const array& a) {
+  return type_to_name(a.dtype());
+}
 MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2 /* = 10 */) {
   int pows[3] = {0, 0, 0};
   int sum = 0;


@@ -8,6 +8,7 @@
 namespace mlx::core {
+std::string type_to_name(const Dtype& t);
 std::string type_to_name(const array& a);
 // Compute the thread block dimensions which fit the given


@@ -1615,7 +1615,14 @@ array sum(
   }
   auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
       compute_reduce_shape(axes, a.shape());
-  auto out_type = a.dtype() == bool_ ? int32 : a.dtype();
+  Dtype out_type = a.dtype();
+  if (issubdtype(a.dtype(), signedinteger)) {
+    out_type = a.dtype().size() <= 4 ? int32 : int64;
+  } else if (issubdtype(a.dtype(), unsignedinteger)) {
+    out_type = a.dtype().size() <= 4 ? uint32 : uint64;
+  } else if (a.dtype() == bool_) {
+    out_type = int32;
+  }
   auto out = (is_noop)
       ? astype(a, out_type, s)
       : array(
@@ -1760,11 +1767,19 @@ array prod(
   }
   auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
       compute_reduce_shape(axes, a.shape());
+  Dtype out_type = a.dtype();
+  if (issubdtype(a.dtype(), signedinteger)) {
+    out_type = a.dtype().size() <= 4 ? int32 : int64;
+  } else if (issubdtype(a.dtype(), unsignedinteger)) {
+    out_type = a.dtype().size() <= 4 ? uint32 : uint64;
+  } else if (a.dtype() == bool_) {
+    out_type = int32;
+  }
   auto out = (is_noop)
       ? a
       : array(
             std::move(out_shape),
-            a.dtype(),
+            out_type,
             std::make_shared<Reduce>(to_stream(s), Reduce::Prod, sorted_axes),
             {a});
   if (!keepdims) {
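The same promotion rule now appears verbatim in both sum and prod. For illustration, it could be factored into a small helper (hypothetical, not part of the diff), which also makes the user-visible behavior easy to read off: integer inputs of 4 bytes or fewer accumulate in 32-bit, wider ones in 64-bit, and bool sums/products become int32, so small-integer reductions no longer wrap.

// Hypothetical helper capturing the shared rule (not an MLX function):
Dtype reduce_out_type(const Dtype& dt) {
  if (issubdtype(dt, signedinteger)) {
    return dt.size() <= 4 ? int32 : int64;
  } else if (issubdtype(dt, unsignedinteger)) {
    return dt.size() <= 4 ? uint32 : uint64;
  } else if (dt == bool_) {
    return int32;
  }
  return dt; // float, bfloat16, complex, etc. are unchanged
}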


@@ -131,6 +131,28 @@ class TestReduce(mlx_tests.MLXTestCase):
             mxsum = y.sum().item()
             self.assertEqual(npsum, mxsum)
+    def test_many_reduction_axes(self):
+        def check(x, axes):
+            expected = x
+            for ax in axes:
+                expected = mx.sum(expected, axis=ax, keepdims=True)
+            out = mx.sum(x, axis=axes, keepdims=True)
+            self.assertTrue(mx.array_equal(out, expected))
+
+        x = mx.random.randint(0, 10, shape=(4, 4, 4, 4, 4))
+        check(x, (0, 2, 4))
+
+        x = mx.random.randint(0, 10, shape=(4, 4, 4, 4, 4, 4, 4))
+        check(x, (0, 2, 4, 6))
+
+        x = mx.random.randint(0, 10, shape=(4, 4, 4, 4, 4, 4, 4, 4, 4))
+        check(x, (0, 2, 4, 6, 8))
+
+        x = mx.random.randint(0, 10, shape=(4, 4, 4, 4, 4, 4, 4, 4, 4, 128))
+        x = x.transpose(1, 0, 2, 3, 4, 5, 6, 7, 8, 9)
+        check(x, (1, 3, 5, 7, 9))
+
 if __name__ == "__main__":
     unittest.main(failfast=True)
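These cases reduce over interleaved axes of 5-, 7-, 9-, and 10-dimensional arrays, so after collapsing contiguous dimensions the reduction rank exceeds the old 1-4 specializations; presumably they exercise the new bucketed dispatch (and the "fix many dims" commit above) by comparing multi-axis sums against a chain of single-axis sums.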