mirror of https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00

Reduce specializations (#1607)

* start of reduce specializations
* fix all reduce
* fix many dims
* fix
* non-jit tests clear
* cleanup instantiations
* cpu merges
* change dim specializations
* optimize
* fix jit
* fix jit
* use higher precision for integer sum+prod
* fixes

This commit is contained in:
parent dcca0d7477
commit 0c5eea226b
@@ -120,48 +120,56 @@ struct MinReduce {
 };

 template <typename InT>
-void reduce_dispatch_out(
+void reduce_dispatch_and_or(
     const array& in,
     array& out,
     Reduce::ReduceType rtype,
     const std::vector<int>& axes) {
-  switch (rtype) {
-    case Reduce::And: {
-      reduction_op<InT, bool>(in, out, axes, true, AndReduce());
-      break;
-    }
-    case Reduce::Or: {
-      reduction_op<InT, bool>(in, out, axes, false, OrReduce());
-      break;
-    }
-    case Reduce::Sum: {
-      auto op = [](auto y, auto x) { (*y) = (*y) + x; };
-      if (out.dtype() == int32) {
-        // special case since the input type can be bool
-        reduction_op<InT, int32_t>(in, out, axes, 0, op);
-      } else {
-        reduction_op<InT, InT>(in, out, axes, 0, op);
-      }
-      break;
-    }
-    case Reduce::Prod: {
-      auto op = [](auto y, auto x) { (*y) *= x; };
-      reduction_op<InT, InT>(in, out, axes, 1, op);
-      break;
-    }
-    case Reduce::Max: {
-      auto init = Limits<InT>::min;
-      reduction_op<InT, InT>(in, out, axes, init, MaxReduce());
-      break;
-    }
-    case Reduce::Min: {
-      auto init = Limits<InT>::max;
-      reduction_op<InT, InT>(in, out, axes, init, MinReduce());
-      break;
-    }
-  }
+  if (rtype == Reduce::And) {
+    reduction_op<InT, bool>(in, out, axes, true, AndReduce());
+  } else {
+    reduction_op<InT, bool>(in, out, axes, false, OrReduce());
+  }
+}
+
+template <typename InT>
+void reduce_dispatch_sum_prod(
+    const array& in,
+    array& out,
+    Reduce::ReduceType rtype,
+    const std::vector<int>& axes) {
+  if (rtype == Reduce::Sum) {
+    auto op = [](auto y, auto x) { (*y) = (*y) + x; };
+    if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
+      reduction_op<InT, int32_t>(in, out, axes, 0, op);
+    } else {
+      reduction_op<InT, InT>(in, out, axes, 0, op);
+    }
+  } else {
+    auto op = [](auto y, auto x) { (*y) *= x; };
+    if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
+      reduction_op<InT, int32_t>(in, out, axes, 1, op);
+    } else {
+      reduction_op<InT, InT>(in, out, axes, 1, op);
+    }
+  }
+}
+
+template <typename InT>
+void reduce_dispatch_min_max(
+    const array& in,
+    array& out,
+    Reduce::ReduceType rtype,
+    const std::vector<int>& axes) {
+  if (rtype == Reduce::Max) {
+    auto init = Limits<InT>::min;
+    reduction_op<InT, InT>(in, out, axes, init, MaxReduce());
+  } else {
+    auto init = Limits<InT>::max;
+    reduction_op<InT, InT>(in, out, axes, init, MinReduce());
+  }
+}

 } // namespace

 void nd_loop(
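Note: the `if constexpr` branches above implement the "use higher precision for integer sum+prod" item from the commit message by accumulating small integer types in int32_t. A minimal standalone sketch (plain C++, hypothetical values) of why a wider accumulator matters:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      // Summing 200 int8 values of 100 overflows an int8 accumulator
      // (max 127) but fits easily in int32.
      std::vector<int8_t> x(200, 100);
      int8_t narrow = 0;
      int32_t wide = 0;
      for (auto v : x) {
        narrow = static_cast<int8_t>(narrow + v); // wraps modulo 256
        wide += v;                                // exact
      }
      std::cout << int(narrow) << " vs " << wide << "\n"; // 32 vs 20000
      return 0;
    }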
@@ -190,46 +198,114 @@ void nd_loop(
 void Reduce::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
-  switch (in.dtype()) {
-    case bool_:
-      reduce_dispatch_out<bool>(in, out, reduce_type_, axes_);
-      break;
-    case uint8:
-      reduce_dispatch_out<uint8_t>(in, out, reduce_type_, axes_);
-      break;
-    case uint16:
-      reduce_dispatch_out<uint16_t>(in, out, reduce_type_, axes_);
-      break;
-    case uint32:
-      reduce_dispatch_out<uint32_t>(in, out, reduce_type_, axes_);
-      break;
-    case uint64:
-      reduce_dispatch_out<uint64_t>(in, out, reduce_type_, axes_);
-      break;
-    case int8:
-      reduce_dispatch_out<uint8_t>(in, out, reduce_type_, axes_);
-      break;
-    case int16:
-      reduce_dispatch_out<uint16_t>(in, out, reduce_type_, axes_);
-      break;
-    case int32:
-      reduce_dispatch_out<int32_t>(in, out, reduce_type_, axes_);
-      break;
-    case int64:
-      reduce_dispatch_out<int64_t>(in, out, reduce_type_, axes_);
-      break;
-    case float16:
-      reduce_dispatch_out<float16_t>(in, out, reduce_type_, axes_);
-      break;
-    case float32:
-      reduce_dispatch_out<float>(in, out, reduce_type_, axes_);
-      break;
-    case bfloat16:
-      reduce_dispatch_out<bfloat16_t>(in, out, reduce_type_, axes_);
-      break;
-    case complex64:
-      reduce_dispatch_out<complex64_t>(in, out, reduce_type_, axes_);
-  }
+  switch (reduce_type_) {
+    case Reduce::And:
+    case Reduce::Or: {
+      switch (in.dtype()) {
+        case bool_:
+        case uint8:
+        case int8:
+          reduce_dispatch_and_or<int8_t>(in, out, reduce_type_, axes_);
+          break;
+        case int16:
+        case uint16:
+        case float16:
+        case bfloat16:
+          reduce_dispatch_and_or<int16_t>(in, out, reduce_type_, axes_);
+          break;
+        case uint32:
+        case int32:
+        case float32:
+          reduce_dispatch_and_or<int32_t>(in, out, reduce_type_, axes_);
+          break;
+        case uint64:
+        case int64:
+        case complex64:
+          reduce_dispatch_and_or<int64_t>(in, out, reduce_type_, axes_);
+          break;
+      }
+      break;
+    }
+    case Reduce::Sum:
+    case Reduce::Prod: {
+      switch (in.dtype()) {
+        case bool_:
+        case uint8:
+        case int8:
+          reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
+          break;
+        case int16:
+        case uint16:
+          reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
+          break;
+        case int32:
+        case uint32:
+          reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
+          break;
+        case int64:
+        case uint64:
+          reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
+          break;
+        case float16:
+          reduce_dispatch_sum_prod<float16_t>(in, out, reduce_type_, axes_);
+          break;
+        case bfloat16:
+          reduce_dispatch_sum_prod<bfloat16_t>(in, out, reduce_type_, axes_);
+          break;
+        case float32:
+          reduce_dispatch_sum_prod<float>(in, out, reduce_type_, axes_);
+          break;
+        case complex64:
+          reduce_dispatch_sum_prod<complex64_t>(in, out, reduce_type_, axes_);
+          break;
+      }
+      break;
+    }
+    case Reduce::Max:
+    case Reduce::Min: {
+      switch (in.dtype()) {
+        case bool_:
+          reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_);
+          break;
+        case uint8:
+          reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
+          break;
+        case uint16:
+          reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
+          break;
+        case uint32:
+          reduce_dispatch_min_max<uint32_t>(in, out, reduce_type_, axes_);
+          break;
+        case uint64:
+          reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);
+          break;
+        case int8:
+          reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
+          break;
+        case int16:
+          reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
+          break;
+        case int32:
+          reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);
+          break;
+        case int64:
+          reduce_dispatch_min_max<int64_t>(in, out, reduce_type_, axes_);
+          break;
+        case float16:
+          reduce_dispatch_min_max<float16_t>(in, out, reduce_type_, axes_);
+          break;
+        case float32:
+          reduce_dispatch_min_max<float>(in, out, reduce_type_, axes_);
+          break;
+        case bfloat16:
+          reduce_dispatch_min_max<bfloat16_t>(in, out, reduce_type_, axes_);
+          break;
+        case complex64:
+          reduce_dispatch_min_max<complex64_t>(in, out, reduce_type_, axes_);
+          break;
+      }
+      break;
+    }
+  }
 }
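Note: the And/Or arm above groups dtypes by byte width rather than by exact type, so one template instantiation serves every dtype of that width. A hedged sketch of the same size-based dispatch idea in isolation (helper name hypothetical):

    #include <cstdint>

    // Pick one integer stand-in per element width; the caller instantiates
    // its kernel once per width instead of once per dtype.
    template <typename F>
    void dispatch_by_width(int itemsize, F&& f) {
      switch (itemsize) {
        case 1: f(int8_t{0}); break;
        case 2: f(int16_t{0}); break;
        case 4: f(int32_t{0}); break;
        default: f(int64_t{0}); break;
      }
    }

    // Usage: dispatch_by_width(in.itemsize(),
    //     [&](auto tag) { run<decltype(tag)>(in, out); });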
@@ -1,5 +1,4 @@
 // Copyright © 2024 Apple Inc.

 #include "mlx/backend/common/compiled.h"
 #include "mlx/backend/metal/jit/arange.h"
 #include "mlx/backend/metal/jit/gemv_masked.h"
@@ -338,17 +337,17 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
     const std::string& kernel_name,
     const std::string& func_name,
     const std::string& op_name,
-    const array& out) {
+    const Dtype& out_type) {
   auto lib = d.get_library(kernel_name, [&]() {
-    std::ostringstream kernel_source;
     std::string op_type = op_name;
     op_type[0] = std::toupper(op_name[0]);
-    auto out_type = get_type_string(out.dtype());
-    std::string op = op_type + "<" + out_type + ">";
-    kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce();
-    kernel_source << get_template_definition(
-        kernel_name, func_name, out_type, op);
-    return kernel_source.str();
+    auto out_t = get_type_string(out_type);
+    std::string op = op_type + "<" + out_t + ">";
+    std::string kernel_source = metal::utils();
+    kernel_source += metal::reduce_utils();
+    kernel_source += metal::reduce();
+    kernel_source += get_template_definition(kernel_name, func_name, out_t, op);
+    return kernel_source;
   });
   return d.get_kernel(kernel_name, lib);
 }
@@ -358,30 +357,31 @@ MTL::ComputePipelineState* get_reduce_kernel(
     const std::string& kernel_name,
     const std::string& func_name,
     const std::string& op_name,
-    const array& in,
-    const array& out,
+    const Dtype& in_type,
+    const Dtype& out_type,
+    const std::string& idx_t,
     int ndim /* = -1 */,
     int bm /* = -1 */,
     int bn /* = -1 */) {
   auto lib = d.get_library(kernel_name, [&]() {
     std::string op_type = op_name;
     op_type[0] = std::toupper(op_name[0]);
-    std::ostringstream kernel_source;
-    auto in_type = get_type_string(in.dtype());
-    auto out_type = get_type_string(out.dtype());
-    std::string op = op_type + "<" + out_type + ">";
-    kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce();
+    auto in_t = get_type_string(in_type);
+    auto out_t = get_type_string(out_type);
+    std::string op = op_type + "<" + out_t + ">";
+    std::string kernel_source = metal::utils();
+    concatenate(kernel_source, metal::reduce_utils(), metal::reduce());
     if (bm >= 0) {
-      kernel_source << get_template_definition(
-          kernel_name, func_name, in_type, out_type, op, ndim, bm, bn);
+      kernel_source += get_template_definition(
+          kernel_name, func_name, in_t, out_t, op, idx_t, ndim, bm, bn);
     } else if (ndim >= 0) {
-      kernel_source << get_template_definition(
-          kernel_name, func_name, in_type, out_type, op, ndim);
+      kernel_source += get_template_definition(
+          kernel_name, func_name, in_t, out_t, op, idx_t, ndim);
     } else {
-      kernel_source << get_template_definition(
-          kernel_name, func_name, in_type, out_type, op);
+      kernel_source += get_template_definition(
+          kernel_name, func_name, in_t, out_t, op, idx_t);
     }
-    return kernel_source.str();
+    return kernel_source;
   });
   auto st = d.get_kernel(kernel_name, lib);
   return st;
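Note: both JIT builders above follow the same pattern: assemble the Metal source as one std::string and hand it to d.get_library, which compiles it only on the first request for that kernel_name. A simplified sketch of that caching pattern (types hypothetical, not the actual metal::Device API):

    #include <functional>
    #include <string>
    #include <unordered_map>

    struct LibraryCache {
      std::unordered_map<std::string, std::string> libs;

      // Build the source lazily and reuse it on later lookups by name.
      const std::string& get_library(
          const std::string& name,
          const std::function<std::string()>& builder) {
        auto it = libs.find(name);
        if (it == libs.end()) {
          it = libs.emplace(name, builder()).first; // compile happens once here
        }
        return it->second;
      }
    };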
@@ -81,15 +81,16 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
     const std::string& kernel_name,
     const std::string& func_name,
     const std::string& op_name,
-    const array& out);
+    const Dtype& out_type);

 MTL::ComputePipelineState* get_reduce_kernel(
     metal::Device& d,
     const std::string& kernel_name,
     const std::string& func_name,
     const std::string& op_name,
-    const array& in,
-    const array& out,
+    const Dtype& in_type,
+    const Dtype& out_type,
+    const std::string& idx_t,
     int ndim = -1,
     int bm = -1,
     int bn = -1);
@@ -10,186 +10,156 @@
 #include "mlx/backend/metal/kernels/reduction/ops.h"
 #include "mlx/backend/metal/kernels/reduce.h"

-#define instantiate_reduce_helper_floats(inst_f, name, op) \
-  inst_f(name, float16, half, op)                          \
-  inst_f(name, float32, float, op)                         \
-  inst_f(name, bfloat16, bfloat16_t, op)
+#define instantiate_init_reduce(name, tname, type, op) \
+  instantiate_kernel("init_reduce_" #name #tname, init_reduce, type, op<type>)

-#define instantiate_reduce_helper_uints(inst_f, name, op) \
-  inst_f(name, uint8, uint8_t, op)                        \
-  inst_f(name, uint16, uint16_t, op)                      \
-  inst_f(name, uint32, uint32_t, op)
+instantiate_init_reduce(and, bool_, bool, And)
+instantiate_init_reduce(or, bool_, bool, Or)

-#define instantiate_reduce_helper_ints(inst_f, name, op) \
-  inst_f(name, int8, int8_t, op)                         \
-  inst_f(name, int16, int16_t, op)                       \
-  inst_f(name, int32, int32_t, op)
+#define instantiate_init_sum_prod(name, op)               \
+  instantiate_init_reduce(name, int32, int32_t, op)       \
+  instantiate_init_reduce(name, int64, int64_t, op)       \
+  instantiate_init_reduce(name, float16, float16_t, op)   \
+  instantiate_init_reduce(name, bfloat16, bfloat16_t, op) \
+  instantiate_init_reduce(name, float32, float, op)       \
+  instantiate_init_reduce(name, complex64, complex64_t, op)

-#define instantiate_reduce_helper_64b(inst_f, name, op) \
-  inst_f(name, int64, int64_t, op)                      \
-  inst_f(name, uint64, uint64_t, op)                    \
-  inst_f(name, complex64, complex64_t, op)
+instantiate_init_sum_prod(sum, Sum)
+instantiate_init_sum_prod(prod, Prod)

-#define instantiate_reduce_helper_types(inst_f, name, op) \
-  instantiate_reduce_helper_floats(inst_f, name, op)      \
-  instantiate_reduce_helper_uints(inst_f, name, op)       \
-  instantiate_reduce_helper_ints(inst_f, name, op)
+#define instantiate_init_min_max(name, op)                \
+  instantiate_init_reduce(name, bool_, bool, op)          \
+  instantiate_init_reduce(name, int8, int8_t, op)         \
+  instantiate_init_reduce(name, int16, int16_t, op)       \
+  instantiate_init_reduce(name, int32, int32_t, op)       \
+  instantiate_init_reduce(name, int64, int64_t, op)       \
+  instantiate_init_reduce(name, uint8, uint8_t, op)       \
+  instantiate_init_reduce(name, uint16, uint16_t, op)     \
+  instantiate_init_reduce(name, uint32, uint32_t, op)     \
+  instantiate_init_reduce(name, uint64, uint64_t, op)     \
+  instantiate_init_reduce(name, float16, float16_t, op)   \
+  instantiate_init_reduce(name, bfloat16, bfloat16_t, op) \
+  instantiate_init_reduce(name, float32, float, op)       \
+  instantiate_init_reduce(name, complex64, complex64_t, op)

-#define instantiate_reduce_ops(inst_f, type_f) \
-  type_f(inst_f, sum, Sum)                     \
-  type_f(inst_f, prod, Prod)                   \
-  type_f(inst_f, min, Min)                     \
-  type_f(inst_f, max, Max)
-
-// Special case for bool reductions
-#define instantiate_reduce_from_types_helper( \
-    inst_f, name, tname, itype, otype, op)    \
-  inst_f(name##tname, itype, otype, op)
-
-#define instantiate_reduce_from_types(inst_f, name, otype, op) \
-  instantiate_reduce_from_types_helper(                        \
-      inst_f, name, bool_, bool, otype, op)                    \
-  instantiate_reduce_from_types_helper(                        \
-      inst_f, name, uint8, uint8_t, otype, op)                 \
-  instantiate_reduce_from_types_helper(                        \
-      inst_f, name, uint16, uint16_t, otype, op)               \
-  instantiate_reduce_from_types_helper(                        \
-      inst_f, name, uint32, uint32_t, otype, op)               \
-  instantiate_reduce_from_types_helper(                        \
-      inst_f, name, uint64, uint64_t, otype, op)               \
-  instantiate_reduce_from_types_helper(                        \
-      inst_f, name, int8, int8_t, otype, op)                   \
-  instantiate_reduce_from_types_helper(                        \
-      inst_f, name, int16, int16_t, otype, op)                 \
-  instantiate_reduce_from_types_helper(                        \
-      inst_f, name, int32, int32_t, otype, op)                 \
-  instantiate_reduce_from_types_helper(                        \
-      inst_f, name, int64, int64_t, otype, op)                 \
-  instantiate_reduce_from_types_helper(                        \
-      inst_f, name, float16, half, otype, op)                  \
-  instantiate_reduce_from_types_helper(                        \
-      inst_f,                                                  \
-      name,                                                    \
-      float32,                                                 \
-      float,                                                   \
-      otype,                                                   \
-      op)                                                      \
-  instantiate_reduce_from_types_helper(                        \
-      inst_f,                                                  \
-      name,                                                    \
-      bfloat16,                                                \
-      bfloat16_t,                                              \
-      otype,                                                   \
-      op)
-
-#define instantiate_init_reduce(name, otype, op) \
-  instantiate_kernel("init_reduce_" #name,       \
-                     init_reduce,                \
-                     otype, op)
-
-#define instantiate_init_reduce_helper(name, tname, type, op) \
-  instantiate_init_reduce(name##tname, type, op<type>)
-
-instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_types)
-instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_64b)
-
-instantiate_init_reduce(andbool_, bool, And<bool>)
-instantiate_init_reduce(orbool_, bool, Or<bool>)
+instantiate_init_min_max(min, Min)
+instantiate_init_min_max(max, Max)

 #define instantiate_all_reduce(name, itype, otype, op) \
   instantiate_kernel("all_reduce_" #name,              \
                      all_reduce,                       \
                      itype, otype, op)

-#define instantiate_same_all_reduce_helper(name, tname, type, op) \
-  instantiate_all_reduce(name##tname, type, type, op<type>)
+#define instantiate_col_reduce_small(name, itype, otype, op, dim)          \
+  instantiate_kernel("col_reduce_small_" #dim "_reduce_" #name,            \
+                     col_reduce_small,                                     \
+                     itype, otype, op, uint, dim)                          \
+  instantiate_kernel("col_reduce_longcolumn_" #dim "_reduce_" #name,       \
+                     col_reduce_longcolumn,                                \
+                     itype, otype, op, uint, dim)                          \
+  instantiate_kernel("col_reduce_small_large_" #dim "_reduce_" #name,      \
+                     col_reduce_small,                                     \
+                     itype, otype, op, size_t, dim)                        \
+  instantiate_kernel("col_reduce_longcolumn_large_" #dim "_reduce_" #name, \
+                     col_reduce_longcolumn,                                \
+                     itype, otype, op, size_t, dim)

-instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_types)
-instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_64b)
+#define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn)        \
+  instantiate_kernel("col_reduce_looped_" #dim "_" #bm "_" #bn "_reduce_" #name,       \
+                     col_reduce_looped,                                                \
+                     itype, otype, op, uint, dim, bm, bn)                              \
+  instantiate_kernel("col_reduce_looped_large_" #dim "_" #bm "_" #bn "_reduce_" #name, \
+                     col_reduce_looped,                                                \
+                     itype, otype, op, size_t, dim, bm, bn)

-instantiate_reduce_from_types(instantiate_all_reduce, and, bool, And<bool>)
-instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or<bool>)
-
-// special case bool with larger output type
-instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)
-
-#define instantiate_col_reduce_small(name, itype, otype, op, dim)    \
-  instantiate_kernel("col_reduce_small_" #dim "_reduce_" #name,      \
-                     col_reduce_small,                               \
-                     itype, otype, op, dim)                          \
-  instantiate_kernel("col_reduce_longcolumn_" #dim "_reduce_" #name, \
-                     col_reduce_longcolumn,                          \
-                     itype, otype, op, dim)
-
-#define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn)  \
-  instantiate_kernel("col_reduce_looped_" #dim "_" #bm "_" #bn "_reduce_" #name, \
-                     col_reduce_looped,                                          \
-                     itype, otype, op, dim, bm, bn)
-
-#define instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, bm, bn)  \
-  instantiate_kernel("col_reduce_2pass_" #dim "_" #bm "_" #bn "_reduce_" #name, \
-                     col_reduce_2pass,                                          \
-                     itype, otype, op, dim, bm, bn)
+#define instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, bm, bn)        \
+  instantiate_kernel("col_reduce_2pass_" #dim "_" #bm "_" #bn "_reduce_" #name,       \
+                     col_reduce_2pass,                                                \
+                     itype, otype, op, uint, dim, bm, bn)                             \
+  instantiate_kernel("col_reduce_2pass_large_" #dim "_" #bm "_" #bn "_reduce_" #name, \
+                     col_reduce_2pass,                                                \
+                     itype, otype, op, size_t, dim, bm, bn)

 #define instantiate_col_reduce_looped(name, itype, otype, op, dim)       \
   instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32) \
   instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, 32, 32)

 #define instantiate_col_reduce_general(name, itype, otype, op) \
-  instantiate_col_reduce_small(name, itype, otype, op, 0)      \
   instantiate_col_reduce_small(name, itype, otype, op, 1)      \
   instantiate_col_reduce_small(name, itype, otype, op, 2)      \
-  instantiate_col_reduce_small(name, itype, otype, op, 3)      \
-  instantiate_col_reduce_small(name, itype, otype, op, 4)      \
-  instantiate_col_reduce_looped(name, itype, otype, op, 0)     \
+  instantiate_col_reduce_small(name, itype, otype, op, 5)      \
   instantiate_col_reduce_looped(name, itype, otype, op, 1)     \
   instantiate_col_reduce_looped(name, itype, otype, op, 2)     \
-  instantiate_col_reduce_looped(name, itype, otype, op, 3)     \
-  instantiate_col_reduce_looped(name, itype, otype, op, 4)
+  instantiate_col_reduce_looped(name, itype, otype, op, 5)

-#define instantiate_same_col_reduce_helper(name, tname, type, op) \
-  instantiate_col_reduce_general(name##tname, type, type, op<type>)
+#define instantiate_row_reduce_small(name, itype, otype, op, dim)     \
+  instantiate_kernel("row_reduce_small_" #dim "_reduce_" #name,       \
+                     row_reduce_small,                                \
+                     itype, otype, op, uint, dim)                     \
+  instantiate_kernel("row_reduce_small_large_" #dim "_reduce_" #name, \
+                     row_reduce_small,                                \
+                     itype, otype, op, size_t, dim)

-instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types)
-instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_64b)
-
-instantiate_col_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
-instantiate_reduce_from_types(instantiate_col_reduce_general, and, bool, And<bool>)
-instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or<bool>)
-
-#define instantiate_row_reduce_small(name, itype, otype, op, dim) \
-  instantiate_kernel("row_reduce_small_" #dim "_reduce_" #name,   \
-                     row_reduce_small,                            \
-                     itype, otype, op, dim)
-
-#define instantiate_row_reduce_looped(name, itype, otype, op, dim) \
-  instantiate_kernel("row_reduce_looped_" #dim "_reduce_" #name,   \
-                     row_reduce_looped,                            \
-                     itype, otype, op, dim)
+#define instantiate_row_reduce_looped(name, itype, otype, op, dim)     \
+  instantiate_kernel("row_reduce_looped_" #dim "_reduce_" #name,       \
+                     row_reduce_looped,                                \
+                     itype, otype, op, uint, dim)                      \
+  instantiate_kernel("row_reduce_looped_large_" #dim "_reduce_" #name, \
+                     row_reduce_looped,                                \
+                     itype, otype, op, size_t, dim)

 #define instantiate_row_reduce_general(name, itype, otype, op) \
-  instantiate_row_reduce_small(name, itype, otype, op, 0)      \
   instantiate_row_reduce_small(name, itype, otype, op, 1)      \
   instantiate_row_reduce_small(name, itype, otype, op, 2)      \
-  instantiate_row_reduce_small(name, itype, otype, op, 3)      \
-  instantiate_row_reduce_small(name, itype, otype, op, 4)      \
-  instantiate_row_reduce_looped(name, itype, otype, op, 0)     \
+  instantiate_row_reduce_small(name, itype, otype, op, 5)      \
   instantiate_row_reduce_looped(name, itype, otype, op, 1)     \
   instantiate_row_reduce_looped(name, itype, otype, op, 2)     \
-  instantiate_row_reduce_looped(name, itype, otype, op, 3)     \
-  instantiate_row_reduce_looped(name, itype, otype, op, 4)     \
+  instantiate_row_reduce_looped(name, itype, otype, op, 5)     \
   instantiate_kernel("row_reduce_simple_" #name,               \
                      row_reduce_simple,                        \
                      itype, otype, op)

-#define instantiate_same_row_reduce_helper(name, tname, type, op) \
-  instantiate_row_reduce_general(name##tname, type, type, op<type>)
+#define instantiate_reduce_functions(name, tname, itype, otype, op)    \
+  instantiate_all_reduce(name##tname, itype, otype, op<otype>)         \
+  instantiate_row_reduce_general(name##tname, itype, otype, op<otype>) \
+  instantiate_col_reduce_general(name##tname, itype, otype, op<otype>)

-instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_types)
-instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_64b)
+#define instantiate_and_or(name, op)                           \
+  instantiate_reduce_functions(name, bool_, bool, bool, op)    \
+  instantiate_reduce_functions(name, int16, int16_t, bool, op) \
+  instantiate_reduce_functions(name, int32, int32_t, bool, op) \
+  instantiate_reduce_functions(name, int64, int64_t, bool, op)

-instantiate_reduce_from_types(instantiate_row_reduce_general, and, bool, And<bool>)
-instantiate_reduce_from_types(instantiate_row_reduce_general, or, bool, Or<bool>)
+instantiate_and_or(and, And)
+instantiate_and_or(or, Or)

-instantiate_row_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
+#define instantiate_sum_prod(name, op)                                       \
+  instantiate_reduce_functions(name, int8, int8_t, int32_t, op)              \
+  instantiate_reduce_functions(name, int16, int16_t, int32_t, op)            \
+  instantiate_reduce_functions(name, int32, int32_t, int32_t, op)            \
+  instantiate_reduce_functions(name, int64, int64_t, int64_t, op)            \
+  instantiate_reduce_functions(name, float16, float16_t, float16_t, op)      \
+  instantiate_reduce_functions(name, bfloat16, bfloat16_t, bfloat16_t, op)   \
+  instantiate_reduce_functions(name, float32, float, float, op)              \
+  instantiate_reduce_functions(name, complex64, complex64_t, complex64_t, op)
+
+instantiate_sum_prod(sum, Sum)
+instantiate_sum_prod(prod, Prod)
+
+#define instantiate_min_max(name, op)                                        \
+  instantiate_reduce_functions(name, int8, int8_t, int8_t, op)               \
+  instantiate_reduce_functions(name, int16, int16_t, int16_t, op)            \
+  instantiate_reduce_functions(name, int32, int32_t, int32_t, op)            \
+  instantiate_reduce_functions(name, int64, int64_t, int64_t, op)            \
+  instantiate_reduce_functions(name, uint8, uint8_t, uint8_t, op)            \
+  instantiate_reduce_functions(name, uint16, uint16_t, uint16_t, op)         \
+  instantiate_reduce_functions(name, uint32, uint32_t, uint32_t, op)         \
+  instantiate_reduce_functions(name, uint64, uint64_t, uint64_t, op)         \
+  instantiate_reduce_functions(name, float16, float16_t, float16_t, op)      \
+  instantiate_reduce_functions(name, bfloat16, bfloat16_t, bfloat16_t, op)   \
+  instantiate_reduce_functions(name, float32, float, float, op)              \
+  instantiate_reduce_functions(name, complex64, complex64_t, complex64_t, op)
+
+instantiate_min_max(min, Min)
+instantiate_min_max(max, Max)
 // clang-format on
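Note: each instantiate_kernel name above must match the string the host assembles at dispatch time. For intuition, a sketch mirroring that assembly (the exact spelling returned by type_to_name is assumed here):

    #include <iostream>
    #include <string>

    int main() {
      std::string func_name = "col_reduce_looped";
      bool large = false; // in.size() > UINT32_MAX selects the size_t variant
      int n = 2;          // get_kernel_reduce_ndim(reduce_ndim)
      std::string kname = func_name + (large ? "_large" : "") + "_" +
          std::to_string(n) + "_32_32_reduce_" + "sum" + "int32";
      std::cout << kname << "\n"; // col_reduce_looped_2_32_32_reduce_sumint32
      return 0;
    }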
@@ -1,6 +1,11 @@
 // Copyright © 2023-2024 Apple Inc.

-template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
+template <
+    typename T,
+    typename U,
+    typename Op,
+    typename IdxT = int64_t,
+    int N_READS = REDUCE_N_READS>
 [[kernel]] void all_reduce(
     const device T* in [[buffer(0)]],
     device U* out [[buffer(1)]],
@@ -16,10 +21,10 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
   threadgroup U shared_vals[simd_size];

   U total = Op::init;
-  int64_t start_idx = gid.y * row_size;
-  int64_t actual_row =
+  IdxT start_idx = gid.y * IdxT(row_size);
+  IdxT actual_row =
       (start_idx + row_size <= in_size) ? row_size : in_size - start_idx;
-  int64_t blocks = actual_row / (lsize.x * N_READS);
+  IdxT blocks = actual_row / (lsize.x * N_READS);
   int extra = actual_row - blocks * (lsize.x * N_READS);
   extra -= lid.x * N_READS;
   start_idx += lid.x * N_READS;
@@ -30,7 +35,7 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
     extra = 0;
   }

-  for (int64_t b = 0; b < blocks; b++) {
+  for (IdxT b = 0; b < blocks; b++) {
     for (int i = 0; i < N_READS; i++) {
       total = op(static_cast<U>(in[i]), total);
     }
@@ -1,6 +1,6 @@
 // Copyright © 2023-2024 Apple Inc.

-template <typename T, typename U, typename Op, int NDIMS>
+template <typename T, typename U, typename Op, typename IdxT, int NDIMS>
 [[kernel]] void col_reduce_small(
     const device T* in [[buffer(0)]],
     device U* out [[buffer(1)]],
@@ -19,7 +19,7 @@ template <typename T, typename U, typename Op, int NDIMS>
     uint3 lsize [[threads_per_threadgroup]]) {
   constexpr int n_reads = 4;
   Op op;
-  looped_elem_to_loc<NDIMS> loop;
+  LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
   const device T* row;

   U totals[n_reads];
@@ -27,20 +27,20 @@ template <typename T, typename U, typename Op, int NDIMS>
     totals[i] = Op::init;
   }

-  size_t column = size_t(gid.x) * lsize.x * n_reads + lid.x * n_reads;
+  IdxT column = IdxT(gid.x) * lsize.x * n_reads + lid.x * n_reads;
   if (column >= reduction_stride) {
     return;
   }
   bool safe = column + n_reads <= reduction_stride;

-  size_t out_idx = gid.y + gsize.y * size_t(gid.z);
-  size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
+  IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
+  IdxT in_idx = elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);
   in += in_idx + column;

-  size_t total_rows = non_col_reductions * reduction_size;
+  IdxT total_rows = IdxT(non_col_reductions) * IdxT(reduction_size);
   loop.next(lid.y, reduce_shape, reduce_strides);
-  for (size_t r = lid.y; r < total_rows; r += lsize.y) {
-    row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
+  for (IdxT r = lid.y; r < total_rows; r += lsize.y) {
+    row = in + loop.location();
     if (safe) {
       for (int i = 0; i < n_reads; i++) {
         totals[i] = op(static_cast<U>(row[i]), totals[i]);
@@ -80,7 +80,7 @@ template <typename T, typename U, typename Op, int NDIMS>
   }

   if (lid.y == 0) {
-    out += out_idx * reduction_stride + column;
+    out += out_idx * IdxT(reduction_stride) + column;
     if (safe) {
       for (int i = 0; i < n_reads; i++) {
         out[i] = totals[i];
@@ -93,7 +93,7 @@ template <typename T, typename U, typename Op, int NDIMS>
   }
 }

-template <typename T, typename U, typename Op, int NDIMS>
+template <typename T, typename U, typename Op, typename IdxT, int NDIMS>
 [[kernel]] void col_reduce_longcolumn(
     const device T* in [[buffer(0)]],
     device U* out [[buffer(1)]],
@@ -112,19 +112,19 @@ template <typename T, typename U, typename Op, int NDIMS>
     uint3 lid [[thread_position_in_threadgroup]],
     uint3 lsize [[threads_per_threadgroup]]) {
   Op op;
-  looped_elem_to_loc<NDIMS> loop;
+  LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
   const device T* row;

-  size_t out_idx = gid.x + gsize.x * size_t(gid.y);
-  size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
+  IdxT out_idx = gid.x + gsize.x * IdxT(gid.y);
+  IdxT in_idx = elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);
   in += in_idx + lid.x;

   U total = Op::init;
-  size_t total_rows = non_col_reductions * reduction_size;
+  IdxT total_rows = IdxT(non_col_reductions) * IdxT(reduction_size);
   loop.next(gid.z * lsize.y + lid.y, reduce_shape, reduce_strides);
-  for (size_t r = gid.z * lsize.y + lid.y; r < total_rows;
+  for (IdxT r = gid.z * lsize.y + lid.y; r < total_rows;
        r += lsize.y * gsize.z) {
-    row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
+    row = in + loop.location();
     total = op(static_cast<U>(*row), total);
     loop.next(lsize.y * gsize.z, reduce_shape, reduce_strides);
   }
@@ -136,7 +136,8 @@ template <typename T, typename U, typename Op, int NDIMS>
     for (uint i = 1; i < lsize.y; i++) {
       total = op(total, shared_vals[i * lsize.x + lid.x]);
     }
-    out[gid.z * out_size + out_idx * reduction_stride + lid.x] = total;
+    out[gid.z * IdxT(out_size) + out_idx * IdxT(reduction_stride) + lid.x] =
+        total;
   }
 }

@@ -151,7 +152,14 @@ template <typename T, typename U, typename Op, int NDIMS>
  * totals with a loop.
  * 7. Write them to the output
  */
-template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
+template <
+    typename T,
+    typename U,
+    typename Op,
+    typename IdxT,
+    int NDIMS,
+    int BM,
+    int BN>
 [[kernel]] void col_reduce_looped(
     const device T* in [[buffer(0)]],
     device U* out [[buffer(1)]],
@@ -176,7 +184,7 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>

   threadgroup U shared_vals[BN * BM];
   U totals[n_reads];
-  looped_elem_to_loc<NDIMS> loop;
+  LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
   const device T* row;

   for (int i = 0; i < n_reads; i++) {
@@ -185,17 +193,17 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>

   short lid = simd_group_id * simd_size + simd_lane_id;
   short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);
-  size_t column = BN * gid.x + offset.x;
+  IdxT column = BN * gid.x + offset.x;
   bool safe = column + n_reads <= reduction_stride;

-  size_t out_idx = gid.y + gsize.y * size_t(gid.z);
-  size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
+  IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
+  IdxT in_idx = elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);
   in += in_idx + column;

-  size_t total = non_col_reductions * reduction_size;
+  IdxT total = IdxT(non_col_reductions) * IdxT(reduction_size);
   loop.next(offset.y, reduce_shape, reduce_strides);
-  for (size_t r = offset.y; r < total; r += BM) {
-    row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
+  for (IdxT r = offset.y; r < total; r += BM) {
+    row = in + loop.location();

     if (safe) {
       for (int i = 0; i < n_reads; i++) {
@@ -235,8 +243,8 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>

   // Write the output.
   if (simd_lane_id == 0) {
-    size_t out_column = BN * gid.x + out_offset.x;
-    out += out_idx * reduction_stride + out_column;
+    IdxT out_column = BN * gid.x + out_offset.x;
+    out += out_idx * IdxT(reduction_stride) + out_column;
     if (out_column + n_outputs <= reduction_stride) {
       for (int i = 0; i < n_outputs; i++) {
         out[i] = totals[i];
@@ -269,7 +277,7 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>

   // Write the output.
   if (offset.y == 0) {
-    out += out_idx * reduction_stride + column;
+    out += out_idx * IdxT(reduction_stride) + column;
     if (safe) {
       for (int i = 0; i < n_reads; i++) {
         out[i] = totals[i];
@@ -283,7 +291,14 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
   }
 }

-template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
+template <
+    typename T,
+    typename U,
+    typename Op,
+    typename IdxT,
+    int NDIMS,
+    int BM,
+    int BN>
 [[kernel]] void col_reduce_2pass(
     const device T* in [[buffer(0)]],
     device U* out [[buffer(1)]],
@@ -312,7 +327,7 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>

   threadgroup U shared_vals[BN * BM];
   U totals[n_reads];
-  looped_elem_to_loc<NDIMS> loop;
+  LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
   const device T* row;

   for (int i = 0; i < n_reads; i++) {
@@ -321,20 +336,19 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>

   short lid = simd_group_id * simd_size + simd_lane_id;
   short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);
-  size_t column = BN * gid.x + offset.x;
+  IdxT column = BN * gid.x + offset.x;
   bool safe = column + n_reads <= reduction_stride;

-  size_t full_idx = gid.y + gsize.y * size_t(gid.z);
-  size_t block_idx = full_idx / out_size;
-  size_t out_idx = full_idx % out_size;
-  size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
+  IdxT full_idx = gid.y + gsize.y * IdxT(gid.z);
+  IdxT block_idx = full_idx / IdxT(out_size);
+  IdxT out_idx = full_idx % IdxT(out_size);
+  IdxT in_idx = elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);
   in += in_idx + column;

-  size_t total = non_col_reductions * reduction_size;
+  IdxT total = IdxT(non_col_reductions) * IdxT(reduction_size);
   loop.next(offset.y + block_idx * BM, reduce_shape, reduce_strides);
-  for (size_t r = offset.y + block_idx * BM; r < total;
-       r += outer_blocks * BM) {
-    row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
+  for (IdxT r = offset.y + block_idx * BM; r < total; r += outer_blocks * BM) {
+    row = in + loop.location();

     if (safe) {
       for (int i = 0; i < n_reads; i++) {
@@ -369,8 +383,8 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>

   // Write the output.
   if (simd_lane_id == 0) {
-    size_t out_column = BN * gid.x + out_offset.x;
-    out += full_idx * reduction_stride + out_column;
+    IdxT out_column = BN * gid.x + out_offset.x;
+    out += full_idx * IdxT(reduction_stride) + out_column;
     if (out_column + n_outputs <= reduction_stride) {
       for (int i = 0; i < n_outputs; i++) {
         out[i] = totals[i];
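Note: the new IdxT parameter threads the index type through every kernel so 32-bit indexing is used whenever the array fits, falling back to 64-bit only for large arrays. The host-side choice reduces to a one-liner, sketched here:

    #include <cstdint>

    // Sketch: pick the index type string that the JIT substitutes for IdxT.
    inline const char* index_type_name(uint64_t n_elements) {
      return n_elements > UINT32_MAX ? "size_t" : "uint";
    }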
@@ -193,6 +193,7 @@ template <
     typename T,
     typename U,
     typename Op,
+    typename IdxT,
     int NDIMS,
     int N_READS = REDUCE_N_READS>
 [[kernel]] void row_reduce_small(
@@ -214,20 +215,20 @@ template <
   Op op;

   U total_val = Op::init;
-  looped_elem_to_loc<NDIMS> loop;
+  LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);

   // Precompute some row reduction numbers
   const device T* row;
-  int blocks = row_size / N_READS;
-  int extra = row_size % N_READS;
+  int blocks = IdxT(row_size) / N_READS;
+  int extra = IdxT(row_size) % N_READS;

   if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) {
     // Simple loop over non_row_reductions and reduce the row in the thread.
-    size_t out_idx = tid.x + tsize.y * size_t(tid.y);
-    in += elem_to_loc(out_idx, shape, strides, ndim);
+    IdxT out_idx = tid.x + tsize.y * IdxT(tid.y);
+    in += elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);

     for (uint r = 0; r < non_row_reductions; r++) {
-      row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
+      row = in + loop.location();
       thread_reduce<T, U, Op, N_READS>(total_val, row, blocks, extra);
       loop.next(reduce_shape, reduce_strides);
     }
@@ -236,13 +237,13 @@ template <
   } else {
     // Collaboratively reduce over non_row_reductions in the simdgroup. Each
     // thread reduces every 32nd row and then a simple simd reduce.
-    size_t out_idx = gid.y + gsize.y * size_t(gid.z);
-    in += elem_to_loc(out_idx, shape, strides, ndim);
+    IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
+    in += elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);

     loop.next(simd_lane_id, reduce_shape, reduce_strides);

     for (uint r = simd_lane_id; r < non_row_reductions; r += simd_size) {
-      row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
+      row = in + loop.location();
       thread_reduce<T, U, Op, N_READS>(total_val, row, blocks, extra);
       loop.next(simd_size, reduce_shape, reduce_strides);
     }
@@ -259,6 +260,7 @@ template <
     typename T,
     typename U,
     typename Op,
+    typename IdxT = size_t,
     int N_READS = REDUCE_N_READS,
     int N_WRITES = REDUCE_N_WRITES>
 [[kernel]] void row_reduce_simple(
@@ -277,15 +279,15 @@ template <
   U totals[N_WRITES];

   // Move to the row
-  size_t out_idx = N_WRITES * (gid.y + gsize.y * size_t(gid.z));
+  IdxT out_idx = N_WRITES * (gid.y + gsize.y * IdxT(gid.z));
   if (out_idx + N_WRITES > out_size) {
     out_idx = out_size - N_WRITES;
   }
-  in += out_idx * reduction_size;
+  in += out_idx * IdxT(reduction_size);
   out += out_idx;

   // Each thread reduces across the row
-  int blocks = reduction_size / (lsize.x * N_READS);
+  int blocks = IdxT(reduction_size) / (lsize.x * N_READS);
   int extra = reduction_size - blocks * (lsize.x * N_READS);
   per_thread_row_reduce<T, U, Op, N_READS, N_WRITES>(
       totals, in, reduction_size, blocks, extra, lsize.x, lid.x);
@@ -306,6 +308,7 @@ template <
     typename T,
     typename U,
     typename Op,
+    typename IdxT,
     int NDIMS,
     int N_READS = REDUCE_N_READS>
 [[kernel]] void row_reduce_looped(
@@ -330,19 +333,20 @@ template <
   threadgroup U shared_vals[simd_size];
   U total = Op::init;

-  size_t out_idx = gid.y + gsize.y * size_t(gid.z);
+  IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);

   // lid.x * N_READS breaks the per_thread_row_reduce interface a bit. Maybe it
   // needs a small refactor.
-  in += elem_to_loc(out_idx, shape, strides, ndim) + lid.x * N_READS;
+  in += elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim) +
+      lid.x * N_READS;

-  looped_elem_to_loc<NDIMS> loop;
+  LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
   const device T* row;
-  int blocks = row_size / (lsize.x * N_READS);
+  int blocks = IdxT(row_size) / (lsize.x * N_READS);
   int extra = row_size - blocks * (lsize.x * N_READS);

-  for (size_t i = 0; i < non_row_reductions; i++) {
-    row = in + loop.location(i, reduce_shape, reduce_strides, reduce_ndim);
+  for (IdxT i = 0; i < non_row_reductions; i++) {
+    row = in + loop.location();

     // Each thread reduces across the row
     U row_total;
@@ -204,16 +204,21 @@ METAL_FUNC vec<IdxT, 3> elem_to_loc_3_nd(
 // Elem to loc in a loop utils
 ///////////////////////////////////////////////////////////////////////////////

-template <int dim, typename offset_t = size_t>
-struct looped_elem_to_loc {
-  looped_elem_to_loc<dim - 1, offset_t> inner_looper;
-  offset_t offset{0};
+template <int DIM, typename OffsetT = size_t, bool General = true>
+struct LoopedElemToLoc {
+  int dim;
+  LoopedElemToLoc<DIM - 1, OffsetT, General> inner_looper;
+  OffsetT offset{0};
   int index{0};

+  LoopedElemToLoc(int dim) : dim(dim), inner_looper(dim - 1) {}
+
   void next(const constant int* shape, const constant size_t* strides) {
+    if (dim == 0) {
+      return;
+    }
     index++;
-    offset += strides[dim - 1];
+    offset += OffsetT(strides[dim - 1]);
     if (index >= shape[dim - 1]) {
       index = 0;
       inner_looper.next(shape, strides);
@@ -222,13 +227,21 @@ struct looped_elem_to_loc {
   }

   void next(int n, const constant int* shape, const constant size_t* strides) {
+    if (dim == 0) {
+      return;
+    }
     index += n;
-    offset += n * strides[dim - 1];
+    offset += n * OffsetT(strides[dim - 1]);

     if (index >= shape[dim - 1]) {
       int extra = index - shape[dim - 1];
+      if (extra >= shape[dim - 1]) {
+        inner_looper.next(1 + extra / shape[dim - 1], shape, strides);
+        extra = extra % shape[dim - 1];
+      } else {
+        inner_looper.next(shape, strides);
+      }
       index = 0;
-      inner_looper.next(shape, strides);
       offset = inner_looper.offset;
       if (extra > 0) {
         next(extra, shape, strides);
@@ -236,44 +249,61 @@ struct looped_elem_to_loc {
     }
   }

-  offset_t
-  location(offset_t, const constant int*, const constant size_t*, int) {
+  OffsetT location() {
     return offset;
   }
 };

-template <typename offset_t>
-struct looped_elem_to_loc<1, offset_t> {
-  offset_t offset{0};
+template <typename OffsetT>
+struct LoopedElemToLoc<1, OffsetT, true> {
+  int dim;
+  OffsetT offset{0};
+  uint index{0};
+
+  LoopedElemToLoc(int dim) : dim(dim) {}
+
+  void next(const constant int* shape, const constant size_t* strides) {
+    index++;
+    if (dim > 1) {
+      offset = elem_to_loc<size_t, OffsetT>(index, shape, strides, dim);
+    } else {
+      offset += OffsetT(strides[0]);
+    }
+  }
+
+  void next(int n, const constant int* shape, const constant size_t* strides) {
+    index += n;
+    if (dim > 1) {
+      offset = elem_to_loc<size_t, OffsetT>(index, shape, strides, dim);
+    } else {
+      offset = index * OffsetT(strides[0]);
+    }
+  }
+
+  OffsetT location() {
+    return offset;
+  }
+};
+
+template <typename OffsetT>
+struct LoopedElemToLoc<1, OffsetT, false> {
+  OffsetT offset{0};
+
+  LoopedElemToLoc(int) {}

   void next(const constant int*, const constant size_t* strides) {
-    offset += strides[0];
+    offset += OffsetT(strides[0]);
   }

   void next(int n, const constant int*, const constant size_t* strides) {
-    offset += n * strides[0];
+    offset += n * OffsetT(strides[0]);
   }

-  offset_t
-  location(offset_t, const constant int*, const constant size_t*, int) {
+  OffsetT location() {
     return offset;
   }
 };

-template <typename offset_t>
-struct looped_elem_to_loc<0, offset_t> {
-  void next(const constant int*, const constant size_t*) {}
-  void next(int, const constant int*, const constant size_t*) {}
-
-  offset_t location(
-      offset_t idx,
-      const constant int* shape,
-      const constant size_t* strides,
-      int ndim) {
-    return elem_to_loc(idx, shape, strides, ndim);
-  }
-};
-
 ///////////////////////////////////////////////////////////////////////////////
 // Calculation utils
 ///////////////////////////////////////////////////////////////////////////////
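Note: LoopedElemToLoc replaces the old looped_elem_to_loc and now carries the loop dimensionality at runtime, so location() becomes a plain accessor instead of recomputing a position from an index. A host-side C++ analog of the 1-D specialization, as a sketch:

    #include <cstddef>

    // Walk one strided axis by keeping a running offset, mirroring
    // LoopedElemToLoc<1, OffsetT, false> above.
    struct StridedWalker {
      size_t offset{0};
      size_t stride;
      explicit StridedWalker(size_t stride) : stride(stride) {}
      void next(int n = 1) { offset += n * stride; }
      size_t location() const { return offset; }
    };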
@@ -1,3 +1,4 @@
 // Copyright © 2024 Apple Inc.

 #include "mlx/backend/metal/kernels.h"
 #include "mlx/backend/metal/utils.h"
@@ -99,7 +100,7 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
     const std::string& kernel_name,
     const std::string&,
     const std::string&,
-    const array&) {
+    const Dtype&) {
   return d.get_kernel(kernel_name);
 }

@@ -108,8 +109,9 @@ MTL::ComputePipelineState* get_reduce_kernel(
     const std::string& kernel_name,
     const std::string&,
     const std::string&,
-    const array&,
-    const array&,
+    const Dtype&,
+    const Dtype&,
+    const std::string&,
     int,
     int,
     int) {
@@ -2,7 +2,6 @@

 #include <algorithm>
 #include <cassert>
 #include <sstream>

 #include "mlx/backend/metal/copy.h"
 #include "mlx/backend/metal/device.h"
@@ -202,6 +201,16 @@ inline bool is_64b_dtype(Dtype dtype) {
   return dtype == int64 || dtype == uint64 || dtype == complex64;
 }

+inline int get_kernel_reduce_ndim(int reduce_ndim) {
+  if (reduce_ndim <= 1) {
+    return 1;
+  } else if (reduce_ndim == 2) {
+    return 2;
+  } else {
+    return 5;
+  }
+}
+
 inline int threadgroup_size_from_row_size(int row_size) {
   // 1 simdgroup per row smallish rows
   if (row_size <= 512) {
@@ -233,16 +242,51 @@ inline auto output_grid_for_col_reduce(
   return get_2d_grid_dims(out_shape, out_strides);
 }

+std::pair<Dtype, Dtype> remap_reduce_types(
+    const array& in,
+    const std::string& op_name) {
+  if (op_name == "sum" || op_name == "prod") {
+    if (issubdtype(in.dtype(), integer)) {
+      switch (in.dtype().size()) {
+        case 1:
+          return {int8, int32};
+        case 2:
+          return {int16, int32};
+        case 4:
+          return {int32, int32};
+        case 8:
+          return {int64, int64};
+      }
+    }
+    if (in.dtype() == bool_) {
+      return {int8, int32};
+    }
+    return {in.dtype(), in.dtype()};
+  } else if (op_name == "and" || op_name == "or") {
+    if (in.dtype().size() == 1) {
+      return {bool_, bool_};
+    } else if (in.dtype().size() == 2) {
+      return {int16, bool_};
+    } else if (in.dtype().size() == 4) {
+      return {int32, bool_};
+    } else {
+      return {int64, bool_};
+    }
+  }
+  return {in.dtype(), in.dtype()};
+}
+
 void init_reduce(
     array& out,
     const std::string& op_name,
     CommandEncoder& compute_encoder,
     metal::Device& d,
     const Stream& s) {
-  std::ostringstream kname;
+  auto [_, out_type] = remap_reduce_types(out, op_name);
   const std::string func_name = "init_reduce";
-  kname << func_name << "_" << op_name << type_to_name(out);
-  auto kernel = get_reduce_init_kernel(d, kname.str(), func_name, op_name, out);
+  std::string kname = func_name;
+  concatenate(kname, "_", op_name, type_to_name(out_type));
+  auto kernel = get_reduce_init_kernel(d, kname, func_name, op_name, out_type);
   size_t nthreads = out.size();
   MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
   NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
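Note: remap_reduce_types is the host-side counterpart of the reduced instantiation set: it maps every input dtype to the (kernel input type, accumulator type) pair that was actually instantiated. A standalone sketch of the sum/prod rule keyed on element size (names hypothetical):

    #include <cstdio>
    #include <utility>

    std::pair<const char*, const char*> remap_sum_prod(int size, bool is_int) {
      if (!is_int) return {"unchanged", "unchanged"};
      switch (size) {
        case 1: return {"int8", "int32"};
        case 2: return {"int16", "int32"};
        case 4: return {"int32", "int32"};
        default: return {"int64", "int64"};
      }
    }

    int main() {
      auto [in, acc] = remap_sum_prod(1, true);
      std::printf("int8 sum -> %s kernel, %s accumulator\n", in, acc);
      return 0;
    }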
@@ -263,10 +307,12 @@ void all_reduce_dispatch(
     metal::Device& d,
     const Stream& s) {
   // Set the kernel
-  std::ostringstream kname;
+  auto [in_type, out_type] = remap_reduce_types(in, op_name);
   const std::string func_name = "all_reduce";
-  kname << func_name << "_" << op_name << type_to_name(in);
-  auto kernel = get_reduce_kernel(d, kname.str(), func_name, op_name, in, out);
+  std::string kname = func_name;
+  concatenate(kname, "_", op_name, type_to_name(in_type));
+  auto kernel = get_reduce_kernel(
+      d, kname, func_name, op_name, in_type, out_type, "int64_t");
   compute_encoder.set_compute_pipeline_state(kernel);

   size_t in_size = in.size();
@@ -300,7 +346,7 @@ void all_reduce_dispatch(
   }

   // Allocate an intermediate tensor to hold results if needed
-  array intermediate({n_rows}, out.dtype(), nullptr, {});
+  array intermediate({n_rows}, out_type, nullptr, {});
   intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
   d.add_temporary(intermediate, s.index);

@@ -318,10 +364,10 @@ void all_reduce_dispatch(
   compute_encoder.dispatch_threads(grid_dims, group_dims);

   // 2nd pass
-  std::ostringstream kname_2nd_pass;
-  kname_2nd_pass << func_name << "_" << op_name << type_to_name(intermediate);
+  std::string kname_2nd_pass = func_name;
+  concatenate(kname_2nd_pass, "_", op_name, type_to_name(intermediate));
   auto kernel_2nd_pass = get_reduce_kernel(
-      d, kname_2nd_pass.str(), func_name, op_name, intermediate, out);
+      d, kname_2nd_pass, func_name, op_name, out_type, out_type, "int64_t");
   compute_encoder.set_compute_pipeline_state(kernel_2nd_pass);
   size_t intermediate_size = n_rows;
   grid_dims = MTL::Size(threadgroup_2nd_pass, 1, 1);
@@ -343,12 +389,30 @@ void row_reduce_small(
     metal::Device& d,
     const Stream& s) {
   // Set the kernel
-  std::ostringstream kname;
-  int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
+  int n = get_kernel_reduce_ndim(args.reduce_ndim);
+  auto [in_type, out_type] = remap_reduce_types(in, op_name);
   const std::string func_name = "row_reduce_small";
-  kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in);
-  auto kernel =
-      get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n);
+  std::string kname = func_name;
+  bool large = in.size() > UINT32_MAX;
+  if (large) {
+    kname += "_large";
+  }
+  concatenate(
+      kname,
+      "_",
+      std::to_string(n),
+      "_reduce_",
+      op_name,
+      type_to_name(in_type));
+  auto kernel = get_reduce_kernel(
+      d,
+      kname,
+      func_name,
+      op_name,
+      in_type,
+      out_type,
+      large ? "size_t" : "uint",
+      n);
   compute_encoder.set_compute_pipeline_state(kernel);

   // Figure out the grid dims
@@ -381,10 +445,13 @@ void row_reduce_simple(
     metal::Device& d,
     const Stream& s) {
   // Set the kernel
-  std::ostringstream kname;
+  auto [in_type, out_type] = remap_reduce_types(in, op_name);
   const std::string func_name = "row_reduce_simple";
-  kname << func_name << "_" << op_name << type_to_name(in);
-  auto kernel = get_reduce_kernel(d, kname.str(), func_name, op_name, in, out);
+  std::string kname = func_name;
+  concatenate(kname, "_", op_name, type_to_name(in_type));
+
+  auto kernel = get_reduce_kernel(
+      d, kname, func_name, op_name, in_type, out_type, "size_t");
   compute_encoder.set_compute_pipeline_state(kernel);

   // Figure out the grid dims
@@ -417,13 +484,32 @@ void row_reduce_looped(
     CommandEncoder& compute_encoder,
     metal::Device& d,
     const Stream& s) {
+  auto [in_type, out_type] = remap_reduce_types(in, op_name);
+
   // Set the kernel
-  std::ostringstream kname;
-  int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
+  int n = get_kernel_reduce_ndim(args.reduce_ndim);
   const std::string func_name = "row_reduce_looped";
-  kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in);
-  auto kernel =
-      get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n);
+  std::string kname = func_name;
+  bool large = in.size() > UINT32_MAX;
+  if (large) {
+    kname += "_large";
+  }
+  concatenate(
+      kname,
+      "_",
+      std::to_string(n),
+      "_reduce_",
+      op_name,
+      type_to_name(in_type));
+  auto kernel = get_reduce_kernel(
+      d,
+      kname,
+      func_name,
+      op_name,
+      in_type,
+      out_type,
+      large ? "size_t" : "uint",
+      n);
   compute_encoder.set_compute_pipeline_state(kernel);

   // Figure out the grid
@ -475,6 +561,8 @@ void strided_reduce_small(
|
||||
CommandEncoder& compute_encoder,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
auto [in_type, out_type] = remap_reduce_types(in, op_name);
|
||||
|
||||
// Figure out the grid dims
|
||||
MTL::Size grid_dims, group_dims;
|
||||
|
||||
@ -483,12 +571,29 @@ void strided_reduce_small(
|
||||
args.reduce_strides.push_back(args.reduction_stride);
|
||||
args.reduce_ndim++;
|
||||
|
||||
int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
|
||||
std::ostringstream kname;
|
||||
int n = get_kernel_reduce_ndim(args.reduce_ndim);
|
||||
const std::string func_name = "col_reduce_small";
|
||||
kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in);
|
||||
auto kernel =
|
||||
get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n);
|
||||
std::string kname = func_name;
|
||||
bool large = in.size() > UINT32_MAX;
|
||||
if (large) {
|
||||
kname += "_large";
|
||||
}
|
||||
concatenate(
|
||||
kname,
|
||||
"_",
|
||||
std::to_string(n),
|
||||
"_reduce_",
|
||||
op_name,
|
||||
type_to_name(in_type));
|
||||
auto kernel = get_reduce_kernel(
|
||||
d,
|
||||
kname,
|
||||
func_name,
|
||||
op_name,
|
||||
in_type,
|
||||
out_type,
|
||||
large ? "size_t" : "uint",
|
||||
n);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
const int n_reads = 4;
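For concreteness, a hypothetical instantiation of the naming scheme: a sum over two reduction dims of an int8 array with more than UINT32_MAX elements (assuming `type_to_name(int8)` yields "int8") assembles:

std::string kname = "col_reduce_small";
kname += "_large"; // in.size() > UINT32_MAX
concatenate(kname, "_", std::to_string(2), "_reduce_", "sum", "int8");
// kname == "col_reduce_small_large_2_reduce_sumint8",
// JIT-built with InT = int8, OutT = int32, IdxT = size_t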
@ -522,6 +627,7 @@ void strided_reduce_longcolumn(
    CommandEncoder& compute_encoder,
    metal::Device& d,
    const Stream& s) {
  auto [in_type, out_type] = remap_reduce_types(in, op_name);
  size_t total_reduction_size = args.reduction_size * args.non_col_reductions;
  size_t outer_blocks = 32;
  if (total_reduction_size >= 32768) {
@ -534,7 +640,7 @@ void strided_reduce_longcolumn(
  intermediate_shape.push_back(outer_blocks);
  intermediate_shape.insert(
      intermediate_shape.end(), out.shape().begin(), out.shape().end());
  array intermediate(std::move(intermediate_shape), out.dtype(), nullptr, {});
  array intermediate(std::move(intermediate_shape), out_type, nullptr, {});
  intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
  d.add_temporary(intermediate, s.index);

@ -556,12 +662,29 @@ void strided_reduce_longcolumn(
  MTL::Size group_dims(threadgroup_x, threadgroup_y, 1);

  // Set the kernel
  int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
  std::ostringstream kname;
  const std::string func_name = "col_reduce_longcolumn";
  kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in);
  auto kernel =
      get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n);
  int n = get_kernel_reduce_ndim(args.reduce_ndim);
  std::string func_name = "col_reduce_longcolumn";
  std::string kname = func_name;
  bool large = in.size() > UINT32_MAX;
  if (large) {
    kname += "_large";
  }
  concatenate(
      kname,
      "_",
      std::to_string(n),
      "_reduce_",
      op_name,
      type_to_name(in_type));
  auto kernel = get_reduce_kernel(
      d,
      kname,
      func_name,
      op_name,
      in_type,
      out_type,
      large ? "size_t" : "uint",
      n);
  compute_encoder.set_compute_pipeline_state(kernel);

  // Launch
@ -581,15 +704,21 @@ void strided_reduce_longcolumn(
  group_dims = MTL::Size(256, 1, 1);

  // Set the 2nd kernel
  const std::string second_kernel = "col_reduce_looped_1_32_32_reduce_" +
      op_name + type_to_name(intermediate);
  func_name = "col_reduce_looped";
  kname = func_name;
  large = intermediate.size() > UINT32_MAX;
  if (large) {
    kname += "_large";
  }
  concatenate(kname, "_1_32_32_reduce_", op_name, type_to_name(intermediate));
  kernel = get_reduce_kernel(
      d,
      second_kernel,
      "col_reduce_looped",
      kname,
      func_name,
      op_name,
      intermediate,
      out,
      intermediate.dtype(),
      out_type,
      large ? "size_t" : "uint",
      1,
      32,
      32);
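The large-input check repeats at every call site and is always taken against the array the kernel actually indexes; note the second pass above checks `intermediate.size()`, not `in.size()`. A hypothetical helper capturing the pattern the diff deliberately leaves inlined at each call site:

#include <cstdint>
#include <string>
#include <utility>

// Hypothetical convenience, not in the diff: kernel-name suffix and Metal
// index type for a given element count.
inline std::pair<std::string, std::string> index_type_for(size_t n_elements) {
  bool large = n_elements > UINT32_MAX;
  return {large ? "_large" : "", large ? "size_t" : "uint"};
}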
@ -609,6 +738,8 @@ void strided_reduce_looped(
    CommandEncoder& compute_encoder,
    metal::Device& d,
    const Stream& s) {
  auto [in_type, out_type] = remap_reduce_types(in, op_name);

  // Prepare the arguments for the kernel
  args.reduce_shape.push_back(args.reduction_size);
  args.reduce_strides.push_back(args.reduction_stride);
@ -626,13 +757,35 @@ void strided_reduce_looped(
  MTL::Size group_dims(threadgroup_size, 1, 1);

  // Set the kernel
  int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
  std::ostringstream kname;
  const std::string func_name = "col_reduce_looped";
  kname << func_name << "_" << n << "_" << BM << "_" << BN << "_reduce_"
        << op_name << type_to_name(in);
  auto kernel =
      get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n, BM, BN);
  int n = get_kernel_reduce_ndim(args.reduce_ndim);
  std::string func_name = "col_reduce_looped";
  std::string kname = func_name;
  bool large = in.size() > UINT32_MAX;
  if (large) {
    kname += "_large";
  }
  concatenate(
      kname,
      "_",
      std::to_string(n),
      "_",
      std::to_string(BM),
      "_",
      std::to_string(BN),
      "_reduce_",
      op_name,
      type_to_name(in_type));
  auto kernel = get_reduce_kernel(
      d,
      kname,
      func_name,
      op_name,
      in_type,
      out_type,
      large ? "size_t" : "uint",
      n,
      BM,
      BN);
  compute_encoder.set_compute_pipeline_state(kernel);

  // Launch
@ -650,13 +803,15 @@ void strided_reduce_2pass(
    CommandEncoder& compute_encoder,
    metal::Device& d,
    const Stream& s) {
  auto [in_type, out_type] = remap_reduce_types(in, op_name);

  // Prepare the temporary accumulator
  std::vector<int> intermediate_shape;
  intermediate_shape.reserve(out.ndim() + 1);
  intermediate_shape.push_back(32);
  intermediate_shape.insert(
      intermediate_shape.end(), out.shape().begin(), out.shape().end());
  array intermediate(std::move(intermediate_shape), out.dtype(), nullptr, {});
  array intermediate(std::move(intermediate_shape), out_type, nullptr, {});
  intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
  d.add_temporary(intermediate, s.index);

@ -679,13 +834,35 @@ void strided_reduce_2pass(
  MTL::Size group_dims(threadgroup_size, 1, 1);

  // Set the kernel
  int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
  std::ostringstream kname;
  const std::string func_name = "col_reduce_2pass";
  kname << func_name << "_" << n << "_" << BM << "_" << BN << "_reduce_"
        << op_name << type_to_name(in);
  auto kernel =
      get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n, BM, BN);
  int n = get_kernel_reduce_ndim(args.reduce_ndim);
  std::string func_name = "col_reduce_2pass";
  std::string kname = func_name;
  bool large = in.size() > UINT32_MAX;
  if (large) {
    kname += "_large";
  }
  concatenate(
      kname,
      "_",
      std::to_string(n),
      "_",
      std::to_string(BM),
      "_",
      std::to_string(BN),
      "_reduce_",
      op_name,
      type_to_name(in_type));
  auto kernel = get_reduce_kernel(
      d,
      kname,
      func_name,
      op_name,
      in_type,
      out_type,
      large ? "size_t" : "uint",
      n,
      BM,
      BN);
  compute_encoder.set_compute_pipeline_state(kernel);

  // Launch
@ -703,15 +880,21 @@ void strided_reduce_2pass(
  grid_dims = MTL::Size(threadgroup_size * ((out.size() + BN - 1) / BN), 1, 1);

  // Set the 2nd kernel
  const std::string second_kernel = "col_reduce_looped_1_32_32_reduce_" +
      op_name + type_to_name(intermediate);
  func_name = "col_reduce_looped";
  kname = func_name;
  large = intermediate.size() > UINT32_MAX;
  if (large) {
    kname += "_large";
  }
  concatenate(kname, "_1_32_32_reduce_", op_name, type_to_name(intermediate));
  kernel = get_reduce_kernel(
      d,
      second_kernel,
      "col_reduce_looped",
      kname,
      func_name,
      op_name,
      intermediate,
      out,
      intermediate.dtype(),
      out_type,
      large ? "size_t" : "uint",
      1,
      32,
      32);
@ -780,7 +963,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
      op_name = "sum";
      break;
    case Reduce::Prod:
      op_name = out.dtype() == bool_ ? "and" : "prod";
      op_name = "prod";
      break;
    case Reduce::Min:
      op_name = out.dtype() == bool_ ? "and" : "min";
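Why the bool special case disappears for Prod: with the promotion added to sum()/prod() in mlx/ops.cpp below, a boolean product is materialized with an int32 output before it ever reaches this dispatch, so remapping it to the "and" kernel is no longer needed. Illustrative invariant (not in the diff):

// After the ops.cpp change, sum/prod outputs are never bool here.
assert(!(reduce_type_ == Reduce::Prod && out.dtype() == bool_));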
@ -6,9 +6,9 @@ using namespace mlx;

namespace mlx::core {

std::string type_to_name(const array& a) {
std::string type_to_name(const Dtype& t) {
  std::string tname;
  switch (a.dtype()) {
  switch (t) {
    case bool_:
      tname = "bool_";
      break;
@ -52,6 +52,10 @@ std::string type_to_name(const array& a) {
  return tname;
}

std::string type_to_name(const array& a) {
  return type_to_name(a.dtype());
}

MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2 /* = 10 */) {
  int pows[3] = {0, 0, 0};
  int sum = 0;

@ -8,6 +8,7 @@

namespace mlx::core {

std::string type_to_name(const Dtype& t);
std::string type_to_name(const array& a);

// Compute the thread block dimensions which fit the given

mlx/ops.cpp
@ -1615,7 +1615,14 @@ array sum(
  }
  auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
      compute_reduce_shape(axes, a.shape());
  auto out_type = a.dtype() == bool_ ? int32 : a.dtype();
  Dtype out_type = a.dtype();
  if (issubdtype(a.dtype(), signedinteger)) {
    out_type = a.dtype().size() <= 4 ? int32 : int64;
  } else if (issubdtype(a.dtype(), unsignedinteger)) {
    out_type = a.dtype().size() <= 4 ? uint32 : uint64;
  } else if (a.dtype() == bool_) {
    out_type = int32;
  }
  auto out = (is_noop)
      ? astype(a, out_type, s)
      : array(
@ -1760,11 +1767,19 @@ array prod(
  }
  auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
      compute_reduce_shape(axes, a.shape());
  Dtype out_type = a.dtype();
  if (issubdtype(a.dtype(), signedinteger)) {
    out_type = a.dtype().size() <= 4 ? int32 : int64;
  } else if (issubdtype(a.dtype(), unsignedinteger)) {
    out_type = a.dtype().size() <= 4 ? uint32 : uint64;
  } else if (a.dtype() == bool_) {
    out_type = int32;
  }
  auto out = (is_noop)
      ? a
      : array(
          std::move(out_shape),
          a.dtype(),
          out_type,
          std::make_shared<Reduce>(to_stream(s), Reduce::Prod, sorted_axes),
          {a});
  if (!keepdims) {
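The same promotion ladder is inlined in both sum() and prod(); factored out for readability (a hypothetical helper, not in the diff): small signed integers accumulate in int32, small unsigned in uint32, 8-byte integers keep their width, bool counts in int32, and floats are untouched.

// Hypothetical refactor of the logic duplicated above.
Dtype reduce_out_type(const Dtype& t) {
  if (issubdtype(t, signedinteger)) {
    return t.size() <= 4 ? int32 : int64;
  }
  if (issubdtype(t, unsignedinteger)) {
    return t.size() <= 4 ? uint32 : uint64;
  }
  return t == bool_ ? int32 : t;
}

This is what prevents, for example, an int8 sum of 1024 elements of value 100 from wrapping: the accumulator is int32, so the result is 102400 rather than an overflowed byte.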
@ -131,6 +131,28 @@ class TestReduce(mlx_tests.MLXTestCase):
            mxsum = y.sum().item()
            self.assertEqual(npsum, mxsum)

    def test_many_reduction_axes(self):

        def check(x, axes):
            expected = x
            for ax in axes:
                expected = mx.sum(expected, axis=ax, keepdims=True)
            out = mx.sum(x, axis=axes, keepdims=True)
            self.assertTrue(mx.array_equal(out, expected))

        x = mx.random.randint(0, 10, shape=(4, 4, 4, 4, 4))
        check(x, (0, 2, 4))

        x = mx.random.randint(0, 10, shape=(4, 4, 4, 4, 4, 4, 4))
        check(x, (0, 2, 4, 6))

        x = mx.random.randint(0, 10, shape=(4, 4, 4, 4, 4, 4, 4, 4, 4))
        check(x, (0, 2, 4, 6, 8))

        x = mx.random.randint(0, 10, shape=(4, 4, 4, 4, 4, 4, 4, 4, 4, 128))
        x = x.transpose(1, 0, 2, 3, 4, 5, 6, 7, 8, 9)
        check(x, (1, 3, 5, 7, 9))


if __name__ == "__main__":
    unittest.main(failfast=True)