mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Reduce specializations (#1607)
* start of reduce specializations * fix all reduce * fix many dims * fix * non-jit tests clear * cleanup instantiations * cpu merges * change dim specializations * optimize * fix jit * fix jit * use higher precision for integer sum+prod * fixes
This commit is contained in:
@@ -10,186 +10,156 @@
|
||||
#include "mlx/backend/metal/kernels/reduction/ops.h"
|
||||
#include "mlx/backend/metal/kernels/reduce.h"
|
||||
|
||||
#define instantiate_reduce_helper_floats(inst_f, name, op) \
|
||||
inst_f(name, float16, half, op) \
|
||||
inst_f(name, float32, float, op) \
|
||||
inst_f(name, bfloat16, bfloat16_t, op)
|
||||
#define instantiate_init_reduce(name, tname, type, op) \
|
||||
instantiate_kernel("init_reduce_" #name #tname, init_reduce, type, op<type>)
|
||||
|
||||
#define instantiate_reduce_helper_uints(inst_f, name, op) \
|
||||
inst_f(name, uint8, uint8_t, op) \
|
||||
inst_f(name, uint16, uint16_t, op) \
|
||||
inst_f(name, uint32, uint32_t, op)
|
||||
instantiate_init_reduce(and, bool_, bool, And)
|
||||
instantiate_init_reduce(or, bool_, bool, Or)
|
||||
|
||||
#define instantiate_reduce_helper_ints(inst_f, name, op) \
|
||||
inst_f(name, int8, int8_t, op) \
|
||||
inst_f(name, int16, int16_t, op) \
|
||||
inst_f(name, int32, int32_t, op)
|
||||
#define instantiate_init_sum_prod(name, op) \
|
||||
instantiate_init_reduce(name, int32, int32_t, op) \
|
||||
instantiate_init_reduce(name, int64, int64_t, op) \
|
||||
instantiate_init_reduce(name, float16, float16_t, op) \
|
||||
instantiate_init_reduce(name, bfloat16, bfloat16_t, op) \
|
||||
instantiate_init_reduce(name, float32, float, op) \
|
||||
instantiate_init_reduce(name, complex64, complex64_t, op)
|
||||
|
||||
#define instantiate_reduce_helper_64b(inst_f, name, op) \
|
||||
inst_f(name, int64, int64_t, op) \
|
||||
inst_f(name, uint64, uint64_t, op) \
|
||||
inst_f(name, complex64, complex64_t, op)
|
||||
instantiate_init_sum_prod(sum, Sum)
|
||||
instantiate_init_sum_prod(prod, Prod)
|
||||
|
||||
#define instantiate_reduce_helper_types(inst_f, name, op) \
|
||||
instantiate_reduce_helper_floats(inst_f, name, op) \
|
||||
instantiate_reduce_helper_uints(inst_f, name, op) \
|
||||
instantiate_reduce_helper_ints(inst_f, name, op)
|
||||
#define instantiate_init_min_max(name, op) \
|
||||
instantiate_init_reduce(name, bool_, bool, op) \
|
||||
instantiate_init_reduce(name, int8, int8_t, op) \
|
||||
instantiate_init_reduce(name, int16, int16_t, op) \
|
||||
instantiate_init_reduce(name, int32, int32_t, op) \
|
||||
instantiate_init_reduce(name, int64, int64_t, op) \
|
||||
instantiate_init_reduce(name, uint8, uint8_t, op) \
|
||||
instantiate_init_reduce(name, uint16, uint16_t, op) \
|
||||
instantiate_init_reduce(name, uint32, uint32_t, op) \
|
||||
instantiate_init_reduce(name, uint64, uint64_t, op) \
|
||||
instantiate_init_reduce(name, float16, float16_t, op) \
|
||||
instantiate_init_reduce(name, bfloat16, bfloat16_t, op) \
|
||||
instantiate_init_reduce(name, float32, float, op) \
|
||||
instantiate_init_reduce(name, complex64, complex64_t, op)
|
||||
|
||||
#define instantiate_reduce_ops(inst_f, type_f) \
|
||||
type_f(inst_f, sum, Sum) \
|
||||
type_f(inst_f, prod, Prod) \
|
||||
type_f(inst_f, min, Min) \
|
||||
type_f(inst_f, max, Max)
|
||||
|
||||
// Special case for bool reductions
|
||||
#define instantiate_reduce_from_types_helper( \
|
||||
inst_f, name, tname, itype, otype, op) \
|
||||
inst_f(name##tname, itype, otype, op)
|
||||
|
||||
#define instantiate_reduce_from_types(inst_f, name, otype, op) \
|
||||
instantiate_reduce_from_types_helper( \
|
||||
inst_f, name, bool_, bool, otype, op) \
|
||||
instantiate_reduce_from_types_helper( \
|
||||
inst_f, name, uint8, uint8_t, otype, op) \
|
||||
instantiate_reduce_from_types_helper( \
|
||||
inst_f, name, uint16, uint16_t, otype, op) \
|
||||
instantiate_reduce_from_types_helper( \
|
||||
inst_f, name, uint32, uint32_t, otype, op) \
|
||||
instantiate_reduce_from_types_helper( \
|
||||
inst_f, name, uint64, uint64_t, otype, op) \
|
||||
instantiate_reduce_from_types_helper( \
|
||||
inst_f, name, int8, int8_t, otype, op) \
|
||||
instantiate_reduce_from_types_helper( \
|
||||
inst_f, name, int16, int16_t, otype, op) \
|
||||
instantiate_reduce_from_types_helper( \
|
||||
inst_f, name, int32, int32_t, otype, op) \
|
||||
instantiate_reduce_from_types_helper( \
|
||||
inst_f, name, int64, int64_t, otype, op) \
|
||||
instantiate_reduce_from_types_helper( \
|
||||
inst_f, name, float16, half, otype, op) \
|
||||
instantiate_reduce_from_types_helper( \
|
||||
inst_f, \
|
||||
name, \
|
||||
float32, \
|
||||
float, \
|
||||
otype, \
|
||||
op) \
|
||||
instantiate_reduce_from_types_helper( \
|
||||
inst_f, \
|
||||
name, \
|
||||
bfloat16, \
|
||||
bfloat16_t, \
|
||||
otype, \
|
||||
op)
|
||||
|
||||
#define instantiate_init_reduce(name, otype, op) \
|
||||
instantiate_kernel("init_reduce_" #name, \
|
||||
init_reduce, \
|
||||
otype, op)
|
||||
|
||||
#define instantiate_init_reduce_helper(name, tname, type, op) \
|
||||
instantiate_init_reduce(name##tname, type, op<type>)
|
||||
|
||||
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_types)
|
||||
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_64b)
|
||||
|
||||
instantiate_init_reduce(andbool_, bool, And<bool>)
|
||||
instantiate_init_reduce(orbool_, bool, Or<bool>)
|
||||
instantiate_init_min_max(min, Min)
|
||||
instantiate_init_min_max(max, Max)
|
||||
|
||||
#define instantiate_all_reduce(name, itype, otype, op) \
|
||||
instantiate_kernel("all_reduce_" #name, \
|
||||
all_reduce, \
|
||||
itype, otype, op)
|
||||
|
||||
#define instantiate_same_all_reduce_helper(name, tname, type, op) \
|
||||
instantiate_all_reduce(name##tname, type, type, op<type>)
|
||||
#define instantiate_col_reduce_small(name, itype, otype, op, dim) \
|
||||
instantiate_kernel("col_reduce_small_" #dim "_reduce_" #name, \
|
||||
col_reduce_small, \
|
||||
itype, otype, op, uint, dim) \
|
||||
instantiate_kernel("col_reduce_longcolumn_" #dim "_reduce_" #name, \
|
||||
col_reduce_longcolumn, \
|
||||
itype, otype, op, uint, dim) \
|
||||
instantiate_kernel("col_reduce_small_large_" #dim "_reduce_" #name, \
|
||||
col_reduce_small, \
|
||||
itype, otype, op, size_t, dim) \
|
||||
instantiate_kernel("col_reduce_longcolumn_large_" #dim "_reduce_" #name, \
|
||||
col_reduce_longcolumn, \
|
||||
itype, otype, op, size_t, dim)
|
||||
|
||||
instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_types)
|
||||
instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_64b)
|
||||
#define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \
|
||||
instantiate_kernel("col_reduce_looped_" #dim "_" #bm "_" #bn "_reduce_" #name, \
|
||||
col_reduce_looped, \
|
||||
itype, otype, op, uint, dim, bm, bn) \
|
||||
instantiate_kernel("col_reduce_looped_large_" #dim "_" #bm "_" #bn "_reduce_" #name, \
|
||||
col_reduce_looped, \
|
||||
itype, otype, op, size_t, dim, bm, bn)
|
||||
|
||||
instantiate_reduce_from_types(instantiate_all_reduce, and, bool, And<bool>)
|
||||
instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or<bool>)
|
||||
|
||||
// special case bool with larger output type
|
||||
instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)
|
||||
|
||||
#define instantiate_col_reduce_small(name, itype, otype, op, dim) \
|
||||
instantiate_kernel("col_reduce_small_" #dim "_reduce_" #name, \
|
||||
col_reduce_small, \
|
||||
itype, otype, op, dim) \
|
||||
instantiate_kernel("col_reduce_longcolumn_" #dim "_reduce_" #name, \
|
||||
col_reduce_longcolumn, \
|
||||
itype, otype, op, dim)
|
||||
|
||||
#define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \
|
||||
instantiate_kernel("col_reduce_looped_" #dim "_" #bm "_" #bn "_reduce_" #name, \
|
||||
col_reduce_looped, \
|
||||
itype, otype, op, dim, bm, bn)
|
||||
|
||||
#define instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, bm, bn) \
|
||||
instantiate_kernel("col_reduce_2pass_" #dim "_" #bm "_" #bn "_reduce_" #name, \
|
||||
col_reduce_2pass, \
|
||||
itype, otype, op, dim, bm, bn)
|
||||
#define instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, bm, bn) \
|
||||
instantiate_kernel("col_reduce_2pass_" #dim "_" #bm "_" #bn "_reduce_" #name, \
|
||||
col_reduce_2pass, \
|
||||
itype, otype, op, uint, dim, bm, bn) \
|
||||
instantiate_kernel("col_reduce_2pass_large_" #dim "_" #bm "_" #bn "_reduce_" #name, \
|
||||
col_reduce_2pass, \
|
||||
itype, otype, op, size_t, dim, bm, bn)
|
||||
|
||||
#define instantiate_col_reduce_looped(name, itype, otype, op, dim) \
|
||||
instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32) \
|
||||
instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, 32, 32)
|
||||
|
||||
#define instantiate_col_reduce_general(name, itype, otype, op) \
|
||||
instantiate_col_reduce_small(name, itype, otype, op, 0) \
|
||||
instantiate_col_reduce_small(name, itype, otype, op, 1) \
|
||||
instantiate_col_reduce_small(name, itype, otype, op, 2) \
|
||||
instantiate_col_reduce_small(name, itype, otype, op, 3) \
|
||||
instantiate_col_reduce_small(name, itype, otype, op, 4) \
|
||||
instantiate_col_reduce_looped(name, itype, otype, op, 0) \
|
||||
instantiate_col_reduce_small(name, itype, otype, op, 5) \
|
||||
instantiate_col_reduce_looped(name, itype, otype, op, 1) \
|
||||
instantiate_col_reduce_looped(name, itype, otype, op, 2) \
|
||||
instantiate_col_reduce_looped(name, itype, otype, op, 3) \
|
||||
instantiate_col_reduce_looped(name, itype, otype, op, 4)
|
||||
instantiate_col_reduce_looped(name, itype, otype, op, 5)
|
||||
|
||||
#define instantiate_same_col_reduce_helper(name, tname, type, op) \
|
||||
instantiate_col_reduce_general(name##tname, type, type, op<type>)
|
||||
#define instantiate_row_reduce_small(name, itype, otype, op, dim) \
|
||||
instantiate_kernel("row_reduce_small_" #dim "_reduce_" #name, \
|
||||
row_reduce_small, \
|
||||
itype, otype, op, uint, dim) \
|
||||
instantiate_kernel("row_reduce_small_large_" #dim "_reduce_" #name, \
|
||||
row_reduce_small, \
|
||||
itype, otype, op, size_t, dim)
|
||||
|
||||
instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types)
|
||||
instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_64b)
|
||||
|
||||
instantiate_col_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
|
||||
instantiate_reduce_from_types(instantiate_col_reduce_general, and, bool, And<bool>)
|
||||
instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or<bool>)
|
||||
|
||||
#define instantiate_row_reduce_small(name, itype, otype, op, dim) \
|
||||
instantiate_kernel("row_reduce_small_" #dim "_reduce_" #name, \
|
||||
row_reduce_small, \
|
||||
itype, otype, op, dim)
|
||||
|
||||
#define instantiate_row_reduce_looped(name, itype, otype, op, dim) \
|
||||
instantiate_kernel("row_reduce_looped_" #dim "_reduce_" #name, \
|
||||
row_reduce_looped, \
|
||||
itype, otype, op, dim)
|
||||
#define instantiate_row_reduce_looped(name, itype, otype, op, dim) \
|
||||
instantiate_kernel("row_reduce_looped_" #dim "_reduce_" #name, \
|
||||
row_reduce_looped, \
|
||||
itype, otype, op, uint, dim) \
|
||||
instantiate_kernel("row_reduce_looped_large_" #dim "_reduce_" #name, \
|
||||
row_reduce_looped, \
|
||||
itype, otype, op, size_t, dim)
|
||||
|
||||
#define instantiate_row_reduce_general(name, itype, otype, op) \
|
||||
instantiate_row_reduce_small(name, itype, otype, op, 0) \
|
||||
instantiate_row_reduce_small(name, itype, otype, op, 1) \
|
||||
instantiate_row_reduce_small(name, itype, otype, op, 2) \
|
||||
instantiate_row_reduce_small(name, itype, otype, op, 3) \
|
||||
instantiate_row_reduce_small(name, itype, otype, op, 4) \
|
||||
instantiate_row_reduce_looped(name, itype, otype, op, 0) \
|
||||
instantiate_row_reduce_small(name, itype, otype, op, 5) \
|
||||
instantiate_row_reduce_looped(name, itype, otype, op, 1) \
|
||||
instantiate_row_reduce_looped(name, itype, otype, op, 2) \
|
||||
instantiate_row_reduce_looped(name, itype, otype, op, 3) \
|
||||
instantiate_row_reduce_looped(name, itype, otype, op, 4) \
|
||||
instantiate_row_reduce_looped(name, itype, otype, op, 5) \
|
||||
instantiate_kernel("row_reduce_simple_" #name, \
|
||||
row_reduce_simple, \
|
||||
itype, otype, op)
|
||||
|
||||
#define instantiate_same_row_reduce_helper(name, tname, type, op) \
|
||||
instantiate_row_reduce_general(name##tname, type, type, op<type>)
|
||||
#define instantiate_reduce_functions(name, tname, itype, otype, op) \
|
||||
instantiate_all_reduce(name##tname, itype, otype, op<otype>) \
|
||||
instantiate_row_reduce_general(name##tname, itype, otype, op<otype>) \
|
||||
instantiate_col_reduce_general(name##tname, itype, otype, op<otype>)
|
||||
|
||||
instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_types)
|
||||
instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_64b)
|
||||
#define instantiate_and_or(name, op) \
|
||||
instantiate_reduce_functions(name, bool_, bool, bool, op) \
|
||||
instantiate_reduce_functions(name, int16, int16_t, bool, op) \
|
||||
instantiate_reduce_functions(name, int32, int32_t, bool, op) \
|
||||
instantiate_reduce_functions(name, int64, int64_t, bool, op)
|
||||
|
||||
instantiate_reduce_from_types(instantiate_row_reduce_general, and, bool, And<bool>)
|
||||
instantiate_reduce_from_types(instantiate_row_reduce_general, or, bool, Or<bool>)
|
||||
instantiate_and_or(and, And)
|
||||
instantiate_and_or(or, Or)
|
||||
|
||||
instantiate_row_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
|
||||
#define instantiate_sum_prod(name, op) \
|
||||
instantiate_reduce_functions(name, int8, int8_t, int32_t, op) \
|
||||
instantiate_reduce_functions(name, int16, int16_t, int32_t, op) \
|
||||
instantiate_reduce_functions(name, int32, int32_t, int32_t, op) \
|
||||
instantiate_reduce_functions(name, int64, int64_t, int64_t, op) \
|
||||
instantiate_reduce_functions(name, float16, float16_t, float16_t, op) \
|
||||
instantiate_reduce_functions(name, bfloat16, bfloat16_t, bfloat16_t, op) \
|
||||
instantiate_reduce_functions(name, float32, float, float, op) \
|
||||
instantiate_reduce_functions(name, complex64, complex64_t, complex64_t, op)
|
||||
|
||||
instantiate_sum_prod(sum, Sum)
|
||||
instantiate_sum_prod(prod, Prod)
|
||||
|
||||
#define instantiate_min_max(name, op) \
|
||||
instantiate_reduce_functions(name, int8, int8_t, int8_t, op) \
|
||||
instantiate_reduce_functions(name, int16, int16_t, int16_t, op) \
|
||||
instantiate_reduce_functions(name, int32, int32_t, int32_t, op) \
|
||||
instantiate_reduce_functions(name, int64, int64_t, int64_t, op) \
|
||||
instantiate_reduce_functions(name, uint8, uint8_t, uint8_t, op) \
|
||||
instantiate_reduce_functions(name, uint16, uint16_t, uint16_t, op) \
|
||||
instantiate_reduce_functions(name, uint32, uint32_t, uint32_t, op) \
|
||||
instantiate_reduce_functions(name, uint64, uint64_t, uint64_t, op) \
|
||||
instantiate_reduce_functions(name, float16, float16_t, float16_t, op) \
|
||||
instantiate_reduce_functions(name, bfloat16, bfloat16_t, bfloat16_t, op) \
|
||||
instantiate_reduce_functions(name, float32, float, float, op) \
|
||||
instantiate_reduce_functions(name, complex64, complex64_t, complex64_t, op)
|
||||
|
||||
instantiate_min_max(min, Min)
|
||||
instantiate_min_max(max, Max)
|
||||
// clang-format on
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
typename IdxT = int64_t,
|
||||
int N_READS = REDUCE_N_READS>
|
||||
[[kernel]] void all_reduce(
|
||||
const device T* in [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
@@ -16,10 +21,10 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
||||
threadgroup U shared_vals[simd_size];
|
||||
|
||||
U total = Op::init;
|
||||
int64_t start_idx = gid.y * row_size;
|
||||
int64_t actual_row =
|
||||
IdxT start_idx = gid.y * IdxT(row_size);
|
||||
IdxT actual_row =
|
||||
(start_idx + row_size <= in_size) ? row_size : in_size - start_idx;
|
||||
int64_t blocks = actual_row / (lsize.x * N_READS);
|
||||
IdxT blocks = actual_row / (lsize.x * N_READS);
|
||||
int extra = actual_row - blocks * (lsize.x * N_READS);
|
||||
extra -= lid.x * N_READS;
|
||||
start_idx += lid.x * N_READS;
|
||||
@@ -30,7 +35,7 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
||||
extra = 0;
|
||||
}
|
||||
|
||||
for (int64_t b = 0; b < blocks; b++) {
|
||||
for (IdxT b = 0; b < blocks; b++) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
total = op(static_cast<U>(in[i]), total);
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
template <typename T, typename U, typename Op, int NDIMS>
|
||||
template <typename T, typename U, typename Op, typename IdxT, int NDIMS>
|
||||
[[kernel]] void col_reduce_small(
|
||||
const device T* in [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
@@ -19,7 +19,7 @@ template <typename T, typename U, typename Op, int NDIMS>
|
||||
uint3 lsize [[threads_per_threadgroup]]) {
|
||||
constexpr int n_reads = 4;
|
||||
Op op;
|
||||
looped_elem_to_loc<NDIMS> loop;
|
||||
LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
|
||||
const device T* row;
|
||||
|
||||
U totals[n_reads];
|
||||
@@ -27,20 +27,20 @@ template <typename T, typename U, typename Op, int NDIMS>
|
||||
totals[i] = Op::init;
|
||||
}
|
||||
|
||||
size_t column = size_t(gid.x) * lsize.x * n_reads + lid.x * n_reads;
|
||||
IdxT column = IdxT(gid.x) * lsize.x * n_reads + lid.x * n_reads;
|
||||
if (column >= reduction_stride) {
|
||||
return;
|
||||
}
|
||||
bool safe = column + n_reads <= reduction_stride;
|
||||
|
||||
size_t out_idx = gid.y + gsize.y * size_t(gid.z);
|
||||
size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
|
||||
IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
|
||||
IdxT in_idx = elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);
|
||||
in += in_idx + column;
|
||||
|
||||
size_t total_rows = non_col_reductions * reduction_size;
|
||||
IdxT total_rows = IdxT(non_col_reductions) * IdxT(reduction_size);
|
||||
loop.next(lid.y, reduce_shape, reduce_strides);
|
||||
for (size_t r = lid.y; r < total_rows; r += lsize.y) {
|
||||
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
|
||||
for (IdxT r = lid.y; r < total_rows; r += lsize.y) {
|
||||
row = in + loop.location();
|
||||
if (safe) {
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
totals[i] = op(static_cast<U>(row[i]), totals[i]);
|
||||
@@ -80,7 +80,7 @@ template <typename T, typename U, typename Op, int NDIMS>
|
||||
}
|
||||
|
||||
if (lid.y == 0) {
|
||||
out += out_idx * reduction_stride + column;
|
||||
out += out_idx * IdxT(reduction_stride) + column;
|
||||
if (safe) {
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
out[i] = totals[i];
|
||||
@@ -93,7 +93,7 @@ template <typename T, typename U, typename Op, int NDIMS>
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, int NDIMS>
|
||||
template <typename T, typename U, typename Op, typename IdxT, int NDIMS>
|
||||
[[kernel]] void col_reduce_longcolumn(
|
||||
const device T* in [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
@@ -112,19 +112,19 @@ template <typename T, typename U, typename Op, int NDIMS>
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint3 lsize [[threads_per_threadgroup]]) {
|
||||
Op op;
|
||||
looped_elem_to_loc<NDIMS> loop;
|
||||
LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
|
||||
const device T* row;
|
||||
|
||||
size_t out_idx = gid.x + gsize.x * size_t(gid.y);
|
||||
size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
|
||||
IdxT out_idx = gid.x + gsize.x * IdxT(gid.y);
|
||||
IdxT in_idx = elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);
|
||||
in += in_idx + lid.x;
|
||||
|
||||
U total = Op::init;
|
||||
size_t total_rows = non_col_reductions * reduction_size;
|
||||
IdxT total_rows = IdxT(non_col_reductions) * IdxT(reduction_size);
|
||||
loop.next(gid.z * lsize.y + lid.y, reduce_shape, reduce_strides);
|
||||
for (size_t r = gid.z * lsize.y + lid.y; r < total_rows;
|
||||
for (IdxT r = gid.z * lsize.y + lid.y; r < total_rows;
|
||||
r += lsize.y * gsize.z) {
|
||||
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
|
||||
row = in + loop.location();
|
||||
total = op(static_cast<U>(*row), total);
|
||||
loop.next(lsize.y * gsize.z, reduce_shape, reduce_strides);
|
||||
}
|
||||
@@ -136,7 +136,8 @@ template <typename T, typename U, typename Op, int NDIMS>
|
||||
for (uint i = 1; i < lsize.y; i++) {
|
||||
total = op(total, shared_vals[i * lsize.x + lid.x]);
|
||||
}
|
||||
out[gid.z * out_size + out_idx * reduction_stride + lid.x] = total;
|
||||
out[gid.z * IdxT(out_size) + out_idx * IdxT(reduction_stride) + lid.x] =
|
||||
total;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -151,7 +152,14 @@ template <typename T, typename U, typename Op, int NDIMS>
|
||||
* totals with a loop.
|
||||
* 7. Write them to the output
|
||||
*/
|
||||
template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
typename IdxT,
|
||||
int NDIMS,
|
||||
int BM,
|
||||
int BN>
|
||||
[[kernel]] void col_reduce_looped(
|
||||
const device T* in [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
@@ -176,7 +184,7 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
|
||||
|
||||
threadgroup U shared_vals[BN * BM];
|
||||
U totals[n_reads];
|
||||
looped_elem_to_loc<NDIMS> loop;
|
||||
LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
|
||||
const device T* row;
|
||||
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
@@ -185,17 +193,17 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
|
||||
|
||||
short lid = simd_group_id * simd_size + simd_lane_id;
|
||||
short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);
|
||||
size_t column = BN * gid.x + offset.x;
|
||||
IdxT column = BN * gid.x + offset.x;
|
||||
bool safe = column + n_reads <= reduction_stride;
|
||||
|
||||
size_t out_idx = gid.y + gsize.y * size_t(gid.z);
|
||||
size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
|
||||
IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
|
||||
IdxT in_idx = elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);
|
||||
in += in_idx + column;
|
||||
|
||||
size_t total = non_col_reductions * reduction_size;
|
||||
IdxT total = IdxT(non_col_reductions) * IdxT(reduction_size);
|
||||
loop.next(offset.y, reduce_shape, reduce_strides);
|
||||
for (size_t r = offset.y; r < total; r += BM) {
|
||||
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
|
||||
for (IdxT r = offset.y; r < total; r += BM) {
|
||||
row = in + loop.location();
|
||||
|
||||
if (safe) {
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
@@ -235,8 +243,8 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
|
||||
|
||||
// Write the output.
|
||||
if (simd_lane_id == 0) {
|
||||
size_t out_column = BN * gid.x + out_offset.x;
|
||||
out += out_idx * reduction_stride + out_column;
|
||||
IdxT out_column = BN * gid.x + out_offset.x;
|
||||
out += out_idx * IdxT(reduction_stride) + out_column;
|
||||
if (out_column + n_outputs <= reduction_stride) {
|
||||
for (int i = 0; i < n_outputs; i++) {
|
||||
out[i] = totals[i];
|
||||
@@ -269,7 +277,7 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
|
||||
|
||||
// Write the output.
|
||||
if (offset.y == 0) {
|
||||
out += out_idx * reduction_stride + column;
|
||||
out += out_idx * IdxT(reduction_stride) + column;
|
||||
if (safe) {
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
out[i] = totals[i];
|
||||
@@ -283,7 +291,14 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
typename IdxT,
|
||||
int NDIMS,
|
||||
int BM,
|
||||
int BN>
|
||||
[[kernel]] void col_reduce_2pass(
|
||||
const device T* in [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
@@ -312,7 +327,7 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
|
||||
|
||||
threadgroup U shared_vals[BN * BM];
|
||||
U totals[n_reads];
|
||||
looped_elem_to_loc<NDIMS> loop;
|
||||
LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
|
||||
const device T* row;
|
||||
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
@@ -321,20 +336,19 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
|
||||
|
||||
short lid = simd_group_id * simd_size + simd_lane_id;
|
||||
short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);
|
||||
size_t column = BN * gid.x + offset.x;
|
||||
IdxT column = BN * gid.x + offset.x;
|
||||
bool safe = column + n_reads <= reduction_stride;
|
||||
|
||||
size_t full_idx = gid.y + gsize.y * size_t(gid.z);
|
||||
size_t block_idx = full_idx / out_size;
|
||||
size_t out_idx = full_idx % out_size;
|
||||
size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
|
||||
IdxT full_idx = gid.y + gsize.y * IdxT(gid.z);
|
||||
IdxT block_idx = full_idx / IdxT(out_size);
|
||||
IdxT out_idx = full_idx % IdxT(out_size);
|
||||
IdxT in_idx = elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);
|
||||
in += in_idx + column;
|
||||
|
||||
size_t total = non_col_reductions * reduction_size;
|
||||
IdxT total = IdxT(non_col_reductions) * IdxT(reduction_size);
|
||||
loop.next(offset.y + block_idx * BM, reduce_shape, reduce_strides);
|
||||
for (size_t r = offset.y + block_idx * BM; r < total;
|
||||
r += outer_blocks * BM) {
|
||||
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
|
||||
for (IdxT r = offset.y + block_idx * BM; r < total; r += outer_blocks * BM) {
|
||||
row = in + loop.location();
|
||||
|
||||
if (safe) {
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
@@ -369,8 +383,8 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
|
||||
|
||||
// Write the output.
|
||||
if (simd_lane_id == 0) {
|
||||
size_t out_column = BN * gid.x + out_offset.x;
|
||||
out += full_idx * reduction_stride + out_column;
|
||||
IdxT out_column = BN * gid.x + out_offset.x;
|
||||
out += full_idx * IdxT(reduction_stride) + out_column;
|
||||
if (out_column + n_outputs <= reduction_stride) {
|
||||
for (int i = 0; i < n_outputs; i++) {
|
||||
out[i] = totals[i];
|
||||
|
||||
@@ -193,6 +193,7 @@ template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
typename IdxT,
|
||||
int NDIMS,
|
||||
int N_READS = REDUCE_N_READS>
|
||||
[[kernel]] void row_reduce_small(
|
||||
@@ -214,20 +215,20 @@ template <
|
||||
Op op;
|
||||
|
||||
U total_val = Op::init;
|
||||
looped_elem_to_loc<NDIMS> loop;
|
||||
LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
|
||||
|
||||
// Precompute some row reduction numbers
|
||||
const device T* row;
|
||||
int blocks = row_size / N_READS;
|
||||
int extra = row_size % N_READS;
|
||||
int blocks = IdxT(row_size) / N_READS;
|
||||
int extra = IdxT(row_size) % N_READS;
|
||||
|
||||
if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) {
|
||||
// Simple loop over non_row_reductions and reduce the row in the thread.
|
||||
size_t out_idx = tid.x + tsize.y * size_t(tid.y);
|
||||
in += elem_to_loc(out_idx, shape, strides, ndim);
|
||||
IdxT out_idx = tid.x + tsize.y * IdxT(tid.y);
|
||||
in += elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);
|
||||
|
||||
for (uint r = 0; r < non_row_reductions; r++) {
|
||||
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
|
||||
row = in + loop.location();
|
||||
thread_reduce<T, U, Op, N_READS>(total_val, row, blocks, extra);
|
||||
loop.next(reduce_shape, reduce_strides);
|
||||
}
|
||||
@@ -236,13 +237,13 @@ template <
|
||||
} else {
|
||||
// Collaboratively reduce over non_row_reductions in the simdgroup. Each
|
||||
// thread reduces every 32nd row and then a simple simd reduce.
|
||||
size_t out_idx = gid.y + gsize.y * size_t(gid.z);
|
||||
in += elem_to_loc(out_idx, shape, strides, ndim);
|
||||
IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
|
||||
in += elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim);
|
||||
|
||||
loop.next(simd_lane_id, reduce_shape, reduce_strides);
|
||||
|
||||
for (uint r = simd_lane_id; r < non_row_reductions; r += simd_size) {
|
||||
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
|
||||
row = in + loop.location();
|
||||
thread_reduce<T, U, Op, N_READS>(total_val, row, blocks, extra);
|
||||
loop.next(simd_size, reduce_shape, reduce_strides);
|
||||
}
|
||||
@@ -259,6 +260,7 @@ template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
typename IdxT = size_t,
|
||||
int N_READS = REDUCE_N_READS,
|
||||
int N_WRITES = REDUCE_N_WRITES>
|
||||
[[kernel]] void row_reduce_simple(
|
||||
@@ -277,15 +279,15 @@ template <
|
||||
U totals[N_WRITES];
|
||||
|
||||
// Move to the row
|
||||
size_t out_idx = N_WRITES * (gid.y + gsize.y * size_t(gid.z));
|
||||
IdxT out_idx = N_WRITES * (gid.y + gsize.y * IdxT(gid.z));
|
||||
if (out_idx + N_WRITES > out_size) {
|
||||
out_idx = out_size - N_WRITES;
|
||||
}
|
||||
in += out_idx * reduction_size;
|
||||
in += out_idx * IdxT(reduction_size);
|
||||
out += out_idx;
|
||||
|
||||
// Each thread reduces across the row
|
||||
int blocks = reduction_size / (lsize.x * N_READS);
|
||||
int blocks = IdxT(reduction_size) / (lsize.x * N_READS);
|
||||
int extra = reduction_size - blocks * (lsize.x * N_READS);
|
||||
per_thread_row_reduce<T, U, Op, N_READS, N_WRITES>(
|
||||
totals, in, reduction_size, blocks, extra, lsize.x, lid.x);
|
||||
@@ -306,6 +308,7 @@ template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
typename IdxT,
|
||||
int NDIMS,
|
||||
int N_READS = REDUCE_N_READS>
|
||||
[[kernel]] void row_reduce_looped(
|
||||
@@ -330,19 +333,20 @@ template <
|
||||
threadgroup U shared_vals[simd_size];
|
||||
U total = Op::init;
|
||||
|
||||
size_t out_idx = gid.y + gsize.y * size_t(gid.z);
|
||||
IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
|
||||
|
||||
// lid.x * N_READS breaks the per_thread_row_reduce interface a bit. Maybe it
|
||||
// needs a small refactor.
|
||||
in += elem_to_loc(out_idx, shape, strides, ndim) + lid.x * N_READS;
|
||||
in += elem_to_loc<size_t, IdxT>(out_idx, shape, strides, ndim) +
|
||||
lid.x * N_READS;
|
||||
|
||||
looped_elem_to_loc<NDIMS> loop;
|
||||
LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
|
||||
const device T* row;
|
||||
int blocks = row_size / (lsize.x * N_READS);
|
||||
int blocks = IdxT(row_size) / (lsize.x * N_READS);
|
||||
int extra = row_size - blocks * (lsize.x * N_READS);
|
||||
|
||||
for (size_t i = 0; i < non_row_reductions; i++) {
|
||||
row = in + loop.location(i, reduce_shape, reduce_strides, reduce_ndim);
|
||||
for (IdxT i = 0; i < non_row_reductions; i++) {
|
||||
row = in + loop.location();
|
||||
|
||||
// Each thread reduces across the row
|
||||
U row_total;
|
||||
|
||||
@@ -204,16 +204,21 @@ METAL_FUNC vec<IdxT, 3> elem_to_loc_3_nd(
|
||||
// Elem to loc in a loop utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int dim, typename offset_t = size_t>
|
||||
struct looped_elem_to_loc {
|
||||
looped_elem_to_loc<dim - 1, offset_t> inner_looper;
|
||||
offset_t offset{0};
|
||||
template <int DIM, typename OffsetT = size_t, bool General = true>
|
||||
struct LoopedElemToLoc {
|
||||
int dim;
|
||||
LoopedElemToLoc<DIM - 1, OffsetT, General> inner_looper;
|
||||
OffsetT offset{0};
|
||||
int index{0};
|
||||
|
||||
void next(const constant int* shape, const constant size_t* strides) {
|
||||
index++;
|
||||
offset += strides[dim - 1];
|
||||
LoopedElemToLoc(int dim) : dim(dim), inner_looper(dim - 1) {}
|
||||
|
||||
void next(const constant int* shape, const constant size_t* strides) {
|
||||
if (dim == 0) {
|
||||
return;
|
||||
}
|
||||
index++;
|
||||
offset += OffsetT(strides[dim - 1]);
|
||||
if (index >= shape[dim - 1]) {
|
||||
index = 0;
|
||||
inner_looper.next(shape, strides);
|
||||
@@ -222,13 +227,21 @@ struct looped_elem_to_loc {
|
||||
}
|
||||
|
||||
void next(int n, const constant int* shape, const constant size_t* strides) {
|
||||
if (dim == 0) {
|
||||
return;
|
||||
}
|
||||
index += n;
|
||||
offset += n * strides[dim - 1];
|
||||
offset += n * OffsetT(strides[dim - 1]);
|
||||
|
||||
if (index >= shape[dim - 1]) {
|
||||
int extra = index - shape[dim - 1];
|
||||
if (extra >= shape[dim - 1]) {
|
||||
inner_looper.next(1 + extra / shape[dim - 1], shape, strides);
|
||||
extra = extra % shape[dim - 1];
|
||||
} else {
|
||||
inner_looper.next(shape, strides);
|
||||
}
|
||||
index = 0;
|
||||
inner_looper.next(shape, strides);
|
||||
offset = inner_looper.offset;
|
||||
if (extra > 0) {
|
||||
next(extra, shape, strides);
|
||||
@@ -236,44 +249,61 @@ struct looped_elem_to_loc {
|
||||
}
|
||||
}
|
||||
|
||||
offset_t
|
||||
location(offset_t, const constant int*, const constant size_t*, int) {
|
||||
OffsetT location() {
|
||||
return offset;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename offset_t>
|
||||
struct looped_elem_to_loc<1, offset_t> {
|
||||
offset_t offset{0};
|
||||
template <typename OffsetT>
|
||||
struct LoopedElemToLoc<1, OffsetT, true> {
|
||||
int dim;
|
||||
OffsetT offset{0};
|
||||
uint index{0};
|
||||
|
||||
LoopedElemToLoc(int dim) : dim(dim) {}
|
||||
|
||||
void next(const constant int* shape, const constant size_t* strides) {
|
||||
index++;
|
||||
if (dim > 1) {
|
||||
offset = elem_to_loc<size_t, OffsetT>(index, shape, strides, dim);
|
||||
} else {
|
||||
offset += OffsetT(strides[0]);
|
||||
}
|
||||
}
|
||||
|
||||
void next(int n, const constant int* shape, const constant size_t* strides) {
|
||||
index += n;
|
||||
if (dim > 1) {
|
||||
offset = elem_to_loc<size_t, OffsetT>(index, shape, strides, dim);
|
||||
} else {
|
||||
offset = index * OffsetT(strides[0]);
|
||||
}
|
||||
}
|
||||
|
||||
OffsetT location() {
|
||||
return offset;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename OffsetT>
|
||||
struct LoopedElemToLoc<1, OffsetT, false> {
|
||||
OffsetT offset{0};
|
||||
|
||||
LoopedElemToLoc(int) {}
|
||||
|
||||
void next(const constant int*, const constant size_t* strides) {
|
||||
offset += strides[0];
|
||||
offset += OffsetT(strides[0]);
|
||||
}
|
||||
|
||||
void next(int n, const constant int*, const constant size_t* strides) {
|
||||
offset += n * strides[0];
|
||||
offset += n * OffsetT(strides[0]);
|
||||
}
|
||||
|
||||
offset_t
|
||||
location(offset_t, const constant int*, const constant size_t*, int) {
|
||||
OffsetT location() {
|
||||
return offset;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename offset_t>
|
||||
struct looped_elem_to_loc<0, offset_t> {
|
||||
void next(const constant int*, const constant size_t*) {}
|
||||
void next(int, const constant int*, const constant size_t*) {}
|
||||
|
||||
offset_t location(
|
||||
offset_t idx,
|
||||
const constant int* shape,
|
||||
const constant size_t* strides,
|
||||
int ndim) {
|
||||
return elem_to_loc(idx, shape, strides, ndim);
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Calculation utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
Reference in New Issue
Block a user