mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-25 01:41:17 +08:00)

Reductions update (#1351)

This commit is contained in:
parent 76f275b4df
commit 248431eb3c
@@ -144,6 +144,13 @@ def reduction(op, axis, x):
     mx.eval(ys)
 
 
+def sum_and_add(axis, x, y):
+    z = x.sum(axis=axis, keepdims=True)
+    for i in range(50):
+        z = (z + y).sum(axis=axis, keepdims=True)
+    mx.eval(z)
+
+
 def softmax(axis, x):
     ys = []
     for i in range(100):
@@ -505,5 +512,8 @@ if __name__ == "__main__":
     elif args.benchmark == "selu":
         print(bench(selu, x))
 
+    elif args.benchmark == "sum_and_add":
+        print(bench(sum_and_add, axis, *xs))
+
     else:
         raise ValueError("Unknown benchmark")
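Note: the new benchmark chains a keepdims reduction with a broadcast add fifty times, which stresses exactly the column-reduction kernels this commit rewrites. A minimal sketch of how it might be driven outside the benchmark harness (the shapes below are illustrative assumptions, not part of the patch):

import mlx.core as mx

def sum_and_add(axis, x, y):
    z = x.sum(axis=axis, keepdims=True)
    for i in range(50):
        z = (z + y).sum(axis=axis, keepdims=True)
    mx.eval(z)

# Reducing axis 0 of a tall matrix exercises the strided (column) reduction
# path; the broadcast add keeps the graph from collapsing to a single sum.
x = mx.random.uniform(shape=(4096, 512))
y = mx.random.uniform(shape=(1, 512))
mx.eval(x, y)
sum_and_add(0, x, y)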
@@ -319,16 +319,18 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
 MTL::ComputePipelineState* get_reduce_init_kernel(
     metal::Device& d,
     const std::string& kernel_name,
+    const std::string& func_name,
+    const std::string& op_name,
     const array& out) {
   auto lib = d.get_library(kernel_name, [&]() {
     std::ostringstream kernel_source;
-    std::string op_type = op_name(out);
-    op_type[0] = std::toupper(op_name(out)[0]);
+    std::string op_type = op_name;
+    op_type[0] = std::toupper(op_name[0]);
     auto out_type = get_type_string(out.dtype());
     std::string op = op_type + "<" + out_type + ">";
     kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce();
     kernel_source << get_template_definition(
-        kernel_name, "init_reduce", out_type, op);
+        kernel_name, func_name, out_type, op);
     return kernel_source.str();
   });
   return d.get_kernel(kernel_name, lib);
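Note: the JIT path now receives the function and op names explicitly instead of deriving them from the output array. A rough Python sketch of the string assembly (the exact template-instantiation text MLX emits may differ; this only mirrors the capitalization and op-type substitution visible in the hunk above):

# Sketch, not MLX source: "sum" -> functor type "Sum<float>", instantiated
# under the host-visible kernel name.
def build_init_reduce_source(kernel_name, func_name, op_name, out_type):
    op_type = op_name[0].upper() + op_name[1:]   # e.g. "sum" -> "Sum"
    op = f"{op_type}<{out_type}>"                # e.g. "Sum<float>"
    return (
        f'template [[host_name("{kernel_name}")]] [[kernel]] '
        f"decltype({func_name}<{out_type}, {op}>) "
        f"{func_name}<{out_type}, {op}>;"
    )

print(build_init_reduce_source(
    "init_reduce_sumfloat32", "init_reduce", "sum", "float"))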
@@ -79,6 +79,8 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
 MTL::ComputePipelineState* get_reduce_init_kernel(
     metal::Device& d,
     const std::string& kernel_name,
+    const std::string& func_name,
+    const std::string& op_name,
     const array& out);
 
 MTL::ComputePipelineState* get_reduce_kernel(
@@ -116,6 +116,9 @@ instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)
 #define instantiate_col_reduce_small(name, itype, otype, op, dim) \
   instantiate_kernel("col_reduce_small_" #dim "_reduce_" #name, \
                      col_reduce_small, \
+                     itype, otype, op, dim) \
+  instantiate_kernel("col_reduce_longcolumn_" #dim "_reduce_" #name, \
+                     col_reduce_longcolumn, \
                      itype, otype, op, dim)
 
 #define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \
@@ -123,9 +126,14 @@ instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)
                      col_reduce_looped, \
                      itype, otype, op, dim, bm, bn)
 
+#define instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, bm, bn) \
+  instantiate_kernel("col_reduce_2pass_" #dim "_" #bm "_" #bn "_reduce_" #name, \
+                     col_reduce_2pass, \
+                     itype, otype, op, dim, bm, bn)
+
 #define instantiate_col_reduce_looped(name, itype, otype, op, dim) \
-  instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 8, 128) \
-  instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32)
+  instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32) \
+  instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, 32, 32)
 
 #define instantiate_col_reduce_general(name, itype, otype, op) \
   instantiate_col_reduce_small(name, itype, otype, op, 0) \
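Note: a small sketch of the kernel names these updated macros register for one op/type combination, following the "#name" stringification convention visible in the hunk (the set of dims instantiated elsewhere in the file is an assumption; only dim 0 is shown above):

# Hypothetical enumeration of generated kernel names for sum over float32.
name, dims, tiles = "sumfloat32", [0, 1, 2], [(32, 32)]
kernels = []
for d in dims:
    kernels.append(f"col_reduce_small_{d}_reduce_{name}")
    kernels.append(f"col_reduce_longcolumn_{d}_reduce_{name}")
    for bm, bn in tiles:
        kernels.append(f"col_reduce_looped_{d}_{bm}_{bn}_reduce_{name}")
        kernels.append(f"col_reduce_2pass_{d}_{bm}_{bn}_reduce_{name}")
print("\n".join(kernels))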
@@ -1,11 +1,6 @@
 // Copyright © 2023-2024 Apple Inc.
 
-template <
-    typename T,
-    typename U,
-    typename Op,
-    int NDIMS,
-    int N_READS = REDUCE_N_READS>
+template <typename T, typename U, typename Op, int NDIMS>
 [[kernel]] void col_reduce_small(
     const device T* in [[buffer(0)]],
     device U* out [[buffer(1)]],
@@ -20,170 +15,128 @@ template <
     const constant size_t& non_col_reductions [[buffer(10)]],
     uint3 gid [[threadgroup_position_in_grid]],
     uint3 gsize [[threadgroups_per_grid]],
-    uint simd_lane_id [[thread_index_in_simdgroup]],
-    uint simd_group_id [[simdgroup_index_in_threadgroup]],
-    uint3 tid [[thread_position_in_grid]],
-    uint3 tsize [[threads_per_grid]]) {
+    uint3 lid [[thread_position_in_threadgroup]],
+    uint3 lsize [[threads_per_threadgroup]]) {
+  constexpr int n_reads = 4;
   Op op;
   looped_elem_to_loc<NDIMS> loop;
   const device T* row;
 
-  // Case 1: Small row small column
-  if (reduction_size * non_col_reductions < 64 && reduction_stride < 32) {
-    U totals[31];
-    for (int i = 0; i < 31; i++) {
-      totals[i] = Op::init;
-    }
-
-    short stride = reduction_stride;
-    short size = reduction_size;
-    short blocks = stride / N_READS;
-    short extra = stride - blocks * N_READS;
-
-    size_t out_idx = tid.x + tsize.y * size_t(tid.y);
-    in += elem_to_loc(out_idx, shape, strides, ndim);
-
-    for (uint r = 0; r < non_col_reductions; r++) {
-      row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
-
-      for (short i = 0; i < size; i++) {
-        for (short j = 0; j < blocks; j++) {
-          for (short k = 0; k < N_READS; k++) {
-            totals[j * N_READS + k] =
-                op(totals[j * N_READS + k],
-                   static_cast<U>(row[i * stride + j * N_READS + k]));
-          }
-        }
-        for (short k = 0; k < extra; k++) {
-          totals[blocks * N_READS + k] =
-              op(totals[blocks * N_READS + k],
-                 static_cast<U>(row[i * stride + blocks * N_READS + k]));
-        }
-      }
-
-      loop.next(reduce_shape, reduce_strides);
-    }
-    out += out_idx * reduction_stride;
-    for (short j = 0; j < stride; j++) {
-      out[j] = totals[j];
-    }
-  }
-
-  // Case 2: Long row small column
-  else if (reduction_size * non_col_reductions < 32) {
-    U totals[N_READS];
-    for (int i = 0; i < N_READS; i++) {
-      totals[i] = Op::init;
-    }
-
-    short size = reduction_size;
-    size_t offset = size_t(tid.x) * N_READS;
-    bool safe = offset + N_READS <= reduction_stride;
-    short extra = reduction_stride - offset;
-
-    size_t out_idx = tid.y + tsize.z * size_t(tid.z);
-    in += elem_to_loc(out_idx, shape, strides, ndim) + offset;
-
-    for (uint r = 0; r < non_col_reductions; r++) {
-      row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
-
-      if (safe) {
-        for (short i = 0; i < size; i++) {
-          for (short j = 0; j < N_READS; j++) {
-            totals[j] =
-                op(static_cast<U>(row[i * reduction_stride + j]), totals[j]);
-          }
-        }
-      } else {
-        for (short i = 0; i < size; i++) {
-          for (short j = 0; j < extra; j++) {
-            totals[j] =
-                op(static_cast<U>(row[i * reduction_stride + j]), totals[j]);
-          }
-        }
-      }
-
-      loop.next(reduce_shape, reduce_strides);
-    }
-    out += out_idx * reduction_stride + offset;
-    if (safe) {
-      for (short i = 0; i < N_READS; i++) {
-        out[i] = totals[i];
-      }
-    } else {
-      for (short i = 0; i < extra; i++) {
-        out[i] = totals[i];
-      }
-    }
-  }
-
-  // Case 3: Long row medium column
-  else {
-    threadgroup U shared_vals[1024];
-    U totals[N_READS];
-    for (int i = 0; i < N_READS; i++) {
-      totals[i] = Op::init;
-    }
-
-    short stride = reduction_stride;
-    short lid = simd_group_id * simd_size + simd_lane_id;
-    short2 tile((stride + N_READS - 1) / N_READS, 32);
-    short2 offset((lid % tile.x) * N_READS, lid / tile.x);
-    short sm_stride = tile.x * N_READS;
-    bool safe = offset.x + N_READS <= stride;
-
-    size_t out_idx = gid.y + gsize.y * size_t(gid.z);
-    in += elem_to_loc(out_idx, shape, strides, ndim) + offset.x;
-
-    // Read cooperatively and contiguously and aggregate the partial results.
-    size_t total = non_col_reductions * reduction_size;
-    loop.next(offset.y, reduce_shape, reduce_strides);
-    for (size_t r = offset.y; r < total; r += simd_size) {
-      row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
-
-      if (safe) {
-        for (int i = 0; i < N_READS; i++) {
-          totals[i] = op(static_cast<U>(row[i]), totals[i]);
-        }
-      } else {
-        U vals[N_READS];
-        for (int i = 0; i < N_READS; i++) {
-          vals[i] = (offset.x + i < stride) ? static_cast<U>(row[i]) : op.init;
-        }
-        for (int i = 0; i < N_READS; i++) {
-          totals[i] = op(vals[i], totals[i]);
-        }
-      }
-
-      loop.next(simd_size, reduce_shape, reduce_strides);
-    }
-
-    // Each thread holds N_READS partial results but the simdgroups are not
-    // aligned to do the reduction across the simdgroup so we write our results
-    // in the shared memory and read them back according to the simdgroup.
-    for (int i = 0; i < N_READS; i++) {
-      shared_vals[offset.y * sm_stride + offset.x + i] = totals[i];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    for (int i = 0; i < N_READS; i++) {
-      totals[i] = op.simd_reduce(
-          shared_vals[simd_lane_id * sm_stride + simd_group_id * N_READS + i]);
-    }
-
-    // Write the output.
-    if (simd_lane_id == 0) {
-      short column = simd_group_id * N_READS;
-      out += out_idx * reduction_stride + column;
-      if (column + N_READS <= stride) {
-        for (int i = 0; i < N_READS; i++) {
-          out[i] = totals[i];
-        }
-      } else {
-        for (int i = 0; column + i < stride; i++) {
-          out[i] = totals[i];
-        }
-      }
-    }
-  }
-}
+  U totals[n_reads];
+  for (int i = 0; i < n_reads; i++) {
+    totals[i] = Op::init;
+  }
+
+  size_t column = size_t(gid.x) * lsize.x * n_reads + lid.x * n_reads;
+  if (column >= reduction_stride) {
+    return;
+  }
+  bool safe = column + n_reads <= reduction_stride;
+
+  size_t out_idx = gid.y + gsize.y * size_t(gid.z);
+  size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
+  in += in_idx + column;
+
+  size_t total_rows = non_col_reductions * reduction_size;
+  loop.next(lid.y, reduce_shape, reduce_strides);
+  for (size_t r = lid.y; r < total_rows; r += lsize.y) {
+    row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
+    if (safe) {
+      for (int i = 0; i < n_reads; i++) {
+        totals[i] = op(static_cast<U>(row[i]), totals[i]);
+      }
+    } else {
+      U vals[n_reads];
+      for (int i = 0; i < n_reads; i++) {
+        vals[i] =
+            (column + i < reduction_stride) ? static_cast<U>(row[i]) : op.init;
+      }
+      for (int i = 0; i < n_reads; i++) {
+        totals[i] = op(vals[i], totals[i]);
+      }
+    }
+    loop.next(lsize.y, reduce_shape, reduce_strides);
+  }
+
+  if (lsize.y > 1) {
+    // lsize.y should be <= 8
+    threadgroup U shared_vals[32 * 8 * n_reads];
+    for (int i = 0; i < n_reads; i++) {
+      shared_vals[lid.y * lsize.x * n_reads + lid.x * n_reads + i] = totals[i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (lid.y == 0) {
+      for (int i = 0; i < n_reads; i++) {
+        totals[i] = shared_vals[lid.x * n_reads + i];
+      }
+      for (uint j = 1; j < lsize.y; j++) {
+        for (int i = 0; i < n_reads; i++) {
+          totals[i] =
+              op(shared_vals[j * lsize.x * n_reads + lid.x * n_reads + i],
+                 totals[i]);
+        }
+      }
+    }
+  }
+
+  if (lid.y == 0) {
+    out += out_idx * reduction_stride + column;
+    if (safe) {
+      for (int i = 0; i < n_reads; i++) {
+        out[i] = totals[i];
+      }
+    } else {
+      for (int i = 0; column + i < reduction_stride; i++) {
+        out[i] = totals[i];
+      }
+    }
+  }
+}
+
+template <typename T, typename U, typename Op, int NDIMS>
+[[kernel]] void col_reduce_longcolumn(
+    const device T* in [[buffer(0)]],
+    device U* out [[buffer(1)]],
+    const constant size_t& reduction_size [[buffer(2)]],
+    const constant size_t& reduction_stride [[buffer(3)]],
+    const constant int* shape [[buffer(4)]],
+    const constant size_t* strides [[buffer(5)]],
+    const constant int& ndim [[buffer(6)]],
+    const constant int* reduce_shape [[buffer(7)]],
+    const constant size_t* reduce_strides [[buffer(8)]],
+    const constant int& reduce_ndim [[buffer(9)]],
+    const constant size_t& non_col_reductions [[buffer(10)]],
+    const constant size_t& out_size [[buffer(11)]],
+    uint3 gid [[threadgroup_position_in_grid]],
+    uint3 gsize [[threadgroups_per_grid]],
+    uint3 lid [[thread_position_in_threadgroup]],
+    uint3 lsize [[threads_per_threadgroup]]) {
+  Op op;
+  looped_elem_to_loc<NDIMS> loop;
+  const device T* row;
+
+  size_t out_idx = gid.x + gsize.x * size_t(gid.y);
+  size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
+  in += in_idx + lid.x;
+
+  U total = Op::init;
+  size_t total_rows = non_col_reductions * reduction_size;
+  loop.next(gid.z * lsize.y + lid.y, reduce_shape, reduce_strides);
+  for (size_t r = gid.z * lsize.y + lid.y; r < total_rows;
+       r += lsize.y * gsize.z) {
+    row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
+    total = op(static_cast<U>(*row), total);
+    loop.next(lsize.y * gsize.z, reduce_shape, reduce_strides);
+  }
+
+  threadgroup U shared_vals[32 * 32];
+  shared_vals[lid.y * lsize.x + lid.x] = total;
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+  if (lid.y == 0) {
+    for (uint i = 1; i < lsize.y; i++) {
+      total = op(total, shared_vals[i * lsize.x + lid.x]);
+    }
+    out[gid.z * out_size + out_idx * reduction_stride + lid.x] = total;
+  }
+}
 
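Note: the rewritten col_reduce_small replaces three heuristic cases with one scheme: each thread owns n_reads adjacent columns, lsize.y threads stride over the rows cooperatively, and their partials are combined through threadgroup memory. A NumPy sketch of that access pattern (a simulation for sum only, not the Metal code):

import numpy as np

def col_reduce_small_sim(x, n_reads=4, lsize_y=8):
    rows, cols = x.shape
    out = np.zeros(cols, x.dtype)
    for col in range(0, cols, n_reads):        # one "thread" per column block
        partials = np.zeros((lsize_y, n_reads), x.dtype)
        for ty in range(lsize_y):              # cooperating threads in y
            for r in range(ty, rows, lsize_y): # strided row loop
                chunk = x[r, col:col + n_reads]
                partials[ty, :len(chunk)] += chunk
        # the threadgroup-memory step: fold the lsize_y partials together
        out[col:col + n_reads] = partials.sum(axis=0)[:cols - col]
    return out

x = np.random.rand(100, 70).astype(np.float32)
assert np.allclose(col_reduce_small_sim(x), x.sum(axis=0), atol=1e-4)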
@@ -216,7 +169,7 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
   Op op;
-  constexpr int n_simdgroups = 4;
+  constexpr int n_simdgroups = 8;
   constexpr short tgp_size = n_simdgroups * simd_size;
   constexpr short n_reads = (BM * BN) / tgp_size;
   constexpr short n_read_blocks = BN / n_reads;
@@ -329,3 +282,103 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
     }
   }
 }
+
+template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
+[[kernel]] void col_reduce_2pass(
+    const device T* in [[buffer(0)]],
+    device U* out [[buffer(1)]],
+    const constant size_t& reduction_size [[buffer(2)]],
+    const constant size_t& reduction_stride [[buffer(3)]],
+    const constant int* shape [[buffer(4)]],
+    const constant size_t* strides [[buffer(5)]],
+    const constant int& ndim [[buffer(6)]],
+    const constant int* reduce_shape [[buffer(7)]],
+    const constant size_t* reduce_strides [[buffer(8)]],
+    const constant int& reduce_ndim [[buffer(9)]],
+    const constant size_t& non_col_reductions [[buffer(10)]],
+    const constant size_t& out_size [[buffer(11)]],
+    uint3 gid [[threadgroup_position_in_grid]],
+    uint3 gsize [[threadgroups_per_grid]],
+    uint simd_lane_id [[thread_index_in_simdgroup]],
+    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
+  Op op;
+  constexpr int n_simdgroups = 8;
+  constexpr short tgp_size = n_simdgroups * simd_size;
+  constexpr short n_reads = (BM * BN) / tgp_size;
+  constexpr short n_read_blocks = BN / n_reads;
+  constexpr int n_outputs = BN / n_simdgroups;
+  constexpr short outer_blocks = 32;
+  static_assert(BM == 32, "BM should be equal to 32");
+
+  threadgroup U shared_vals[BN * BM];
+  U totals[n_reads];
+  looped_elem_to_loc<NDIMS> loop;
+  const device T* row;
+
+  for (int i = 0; i < n_reads; i++) {
+    totals[i] = Op::init;
+  }
+
+  short lid = simd_group_id * simd_size + simd_lane_id;
+  short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);
+  size_t column = BN * gid.x + offset.x;
+  bool safe = column + n_reads <= reduction_stride;
+
+  size_t full_idx = gid.y + gsize.y * size_t(gid.z);
+  size_t block_idx = full_idx / out_size;
+  size_t out_idx = full_idx % out_size;
+  size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
+  in += in_idx + column;
+
+  size_t total = non_col_reductions * reduction_size;
+  loop.next(offset.y + block_idx * BM, reduce_shape, reduce_strides);
+  for (size_t r = offset.y + block_idx * BM; r < total;
+       r += outer_blocks * BM) {
+    row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
+
+    if (safe) {
+      for (int i = 0; i < n_reads; i++) {
+        totals[i] = op(static_cast<U>(row[i]), totals[i]);
+      }
+    } else {
+      U vals[n_reads];
+      for (int i = 0; i < n_reads; i++) {
+        vals[i] =
+            (column + i < reduction_stride) ? static_cast<U>(row[i]) : op.init;
+      }
+      for (int i = 0; i < n_reads; i++) {
+        totals[i] = op(vals[i], totals[i]);
+      }
+    }
+
+    loop.next(outer_blocks * BM, reduce_shape, reduce_strides);
+  }
+
+  // We can use a simd reduction to accumulate across BM so each thread writes
+  // the partial output to SM and then each simdgroup does BN / n_simdgroups
+  // accumulations.
+  for (int i = 0; i < n_reads; i++) {
+    shared_vals[offset.y * BN + offset.x + i] = totals[i];
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+  short2 out_offset(simd_group_id * n_outputs, simd_lane_id);
+  for (int i = 0; i < n_outputs; i++) {
+    totals[i] =
+        op.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]);
+  }
+
+  // Write the output.
+  if (simd_lane_id == 0) {
+    size_t out_column = BN * gid.x + out_offset.x;
+    out += full_idx * reduction_stride + out_column;
+    if (out_column + n_outputs <= reduction_stride) {
+      for (int i = 0; i < n_outputs; i++) {
+        out[i] = totals[i];
+      }
+    } else {
+      for (int i = 0; out_column + i < reduction_stride; i++) {
+        out[i] = totals[i];
+      }
+    }
+  }
+}
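Note: col_reduce_2pass targets the case where the output is small, so a single pass would launch too few threadgroups to fill the GPU. A NumPy sketch of the idea (a sum-only simulation under the patch's outer_blocks = 32, not the Metal code):

import numpy as np

def col_reduce_2pass_sim(x, outer_blocks=32):
    rows, cols = x.shape
    intermediate = np.zeros((outer_blocks, cols), x.dtype)
    for b in range(outer_blocks):               # pass 1: independent row blocks
        intermediate[b] = x[b::outer_blocks].sum(axis=0)
    return intermediate.sum(axis=0)             # pass 2: fold the 32 partials

x = np.random.rand(10000, 8).astype(np.float32)
assert np.allclose(col_reduce_2pass_sim(x), x.sum(axis=0), rtol=1e-3)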
@@ -97,6 +97,8 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
 MTL::ComputePipelineState* get_reduce_init_kernel(
     metal::Device& d,
     const std::string& kernel_name,
+    const std::string&,
+    const std::string&,
     const array&) {
   return d.get_kernel(kernel_name);
 }
@@ -141,6 +141,20 @@ struct ColReduceArgs {
     ndim = shape.size();
   }
 
+  /**
+   * Create the col reduce arguments for reducing the 1st axis of the row
+   * contiguous intermediate array.
+   */
+  ColReduceArgs(const array& intermediate) {
+    assert(intermediate.flags().row_contiguous);
+
+    reduction_size = intermediate.shape(0);
+    reduction_stride = intermediate.size() / reduction_size;
+    non_col_reductions = 1;
+    reduce_ndim = 0;
+    ndim = 0;
+  }
+
   void encode(CommandEncoder& compute_encoder) {
     // Push 0s to avoid encoding empty vectors.
     if (reduce_ndim == 0) {
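Note: this constructor feeds the second pass of the longcolumn and 2pass paths. A short sketch of what it encodes (illustrative shapes):

import numpy as np

# For a row-contiguous intermediate of shape (32, *out_shape), reducing
# axis 0 is a column reduction with reduction_size = 32 and
# reduction_stride = intermediate.size / 32 (the number of output elements).
intermediate = np.zeros((32, 4, 7), np.float32)
reduction_size = intermediate.shape[0]                  # 32
reduction_stride = intermediate.size // reduction_size  # 28 contiguous columns
out = intermediate.reshape(reduction_size, reduction_stride).sum(axis=0)
assert out.shape == (reduction_stride,)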
@@ -231,8 +245,10 @@ void init_reduce(
     CommandEncoder& compute_encoder,
     metal::Device& d,
     const Stream& s) {
-  auto kernel = get_reduce_init_kernel(
-      d, "init_reduce_" + op_name + type_to_name(out), out);
+  std::ostringstream kname;
+  const std::string func_name = "init_reduce";
+  kname << func_name << "_" << op_name << type_to_name(out);
+  auto kernel = get_reduce_init_kernel(d, kname.str(), func_name, op_name, out);
   size_t nthreads = out.size();
   MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
   NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
@@ -251,8 +267,7 @@ void all_reduce_dispatch(
     const std::string& op_name,
     CommandEncoder& compute_encoder,
     metal::Device& d,
-    const Stream& s,
-    std::vector<array>& copies) {
+    const Stream& s) {
   // Set the kernel
   std::ostringstream kname;
   const std::string func_name = "all_reduce";
@@ -293,7 +308,7 @@ void all_reduce_dispatch(
   // Allocate an intermediate tensor to hold results if needed
   array intermediate({n_rows}, out.dtype(), nullptr, {});
   intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
-  copies.push_back(intermediate);
+  d.add_temporary(intermediate, s.index);
 
   // 1st pass
   size_t row_size = (in_size + n_rows - 1) / n_rows;
@@ -469,39 +484,11 @@ void strided_reduce_small(
   // Figure out the grid dims
   MTL::Size grid_dims, group_dims;
 
-  // Case 1: Small row small column
-  if (args.reduction_size * args.non_col_reductions < 64 &&
-      args.reduction_stride < 32) {
-    grid_dims = output_grid_for_col_reduce(out, args);
-    int threadgroup_size = (grid_dims.width > 128) ? 128 : grid_dims.width;
-    group_dims = MTL::Size(threadgroup_size, 1, 1);
-  }
-
-  // Case 2: Long row small column
-  else if (args.reduction_size * args.non_col_reductions < 32) {
-    auto out_grid_dims = output_grid_for_col_reduce(out, args);
-    int threads_x =
-        (args.reduction_stride + REDUCE_N_READS - 1) / REDUCE_N_READS;
-    int threadgroup_x = std::min(threads_x, 128);
-    grid_dims = MTL::Size(threads_x, out_grid_dims.width, out_grid_dims.height);
-    group_dims = MTL::Size(threadgroup_x, 1, 1);
-  }
-
-  // Case 3: Long row medium column
-  else {
-    args.reduce_shape.push_back(args.reduction_size);
-    args.reduce_strides.push_back(args.reduction_stride);
-    args.reduce_ndim++;
-    int simdgroups =
-        (args.reduction_stride + REDUCE_N_READS - 1) / REDUCE_N_READS;
-    int threadgroup_size = simdgroups * 32;
-    auto out_grid_dims = output_grid_for_col_reduce(out, args);
-    grid_dims =
-        MTL::Size(threadgroup_size, out_grid_dims.width, out_grid_dims.height);
-    group_dims = MTL::Size(threadgroup_size, 1, 1);
-  }
-
-  // Set the kernel
+  // Prepare the arguments for the kernel
+  args.reduce_shape.push_back(args.reduction_size);
+  args.reduce_strides.push_back(args.reduction_stride);
+  args.reduce_ndim++;
+
   int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
   std::ostringstream kname;
   const std::string func_name = "col_reduce_small";
@@ -510,10 +497,113 @@ void strided_reduce_small(
       get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n);
   compute_encoder->setComputePipelineState(kernel);
 
+  const int n_reads = 4;
+  size_t reduction_stride_blocks =
+      (args.reduction_stride + n_reads - 1) / n_reads;
+  size_t total = args.reduction_size * args.non_col_reductions;
+  size_t threadgroup_x = std::min(reduction_stride_blocks, 32ul);
+  size_t threadgroup_y = std::min(
+      8ul,
+      std::min(kernel->maxTotalThreadsPerThreadgroup() / threadgroup_x, total));
+
+  group_dims = MTL::Size(threadgroup_x, threadgroup_y, 1);
+  grid_dims = output_grid_for_col_reduce(out, args);
+  grid_dims = MTL::Size(
+      (reduction_stride_blocks + threadgroup_x - 1) / threadgroup_x,
+      grid_dims.width,
+      grid_dims.height);
+
   // Launch
   compute_encoder.set_input_array(in, 0);
   compute_encoder.set_output_array(out, 1);
   args.encode(compute_encoder);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+}
+
+void strided_reduce_longcolumn(
+    const array& in,
+    array& out,
+    const std::string& op_name,
+    ColReduceArgs& args,
+    CommandEncoder& compute_encoder,
+    metal::Device& d,
+    const Stream& s) {
+  size_t total_reduction_size = args.reduction_size * args.non_col_reductions;
+  size_t outer_blocks = 32;
+  if (total_reduction_size >= 32768) {
+    outer_blocks = 128;
+  }
+
+  // Prepare the temporary accumulator
+  std::vector<int> intermediate_shape;
+  intermediate_shape.reserve(out.ndim() + 1);
+  intermediate_shape.push_back(outer_blocks);
+  intermediate_shape.insert(
+      intermediate_shape.end(), out.shape().begin(), out.shape().end());
+  array intermediate(std::move(intermediate_shape), out.dtype(), nullptr, {});
+  intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
+  d.add_temporary(intermediate, s.index);
+
+  // Prepare the arguments for the kernel
+  args.reduce_shape.push_back(args.reduction_size);
+  args.reduce_strides.push_back(args.reduction_stride);
+  args.reduce_ndim++;
+
+  // Figure out the grid dims
+  size_t out_size = out.size();
+  size_t threadgroup_x = args.reduction_stride;
+  size_t threadgroup_y =
+      (args.non_col_reductions * args.reduction_size + outer_blocks - 1) /
+      outer_blocks;
+  threadgroup_y = std::min(32ul, threadgroup_y);
+
+  auto out_grid_size = output_grid_for_col_reduce(out, args);
+  MTL::Size grid_dims(out_grid_size.width, out_grid_size.height, outer_blocks);
+  MTL::Size group_dims(threadgroup_x, threadgroup_y, 1);
+
+  // Set the kernel
+  int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
+  std::ostringstream kname;
+  const std::string func_name = "col_reduce_longcolumn";
+  kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in);
+  auto kernel =
+      get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n);
+  compute_encoder->setComputePipelineState(kernel);
+
+  // Launch
+  compute_encoder.set_input_array(in, 0);
+  compute_encoder.set_output_array(intermediate, 1);
+  args.encode(compute_encoder);
+  compute_encoder->setBytes(&out_size, sizeof(size_t), 11);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+
+  // Make the 2nd pass arguments and grid_dims
+  ColReduceArgs second_args(intermediate);
+  second_args.reduce_shape.push_back(outer_blocks);
+  second_args.reduce_strides.push_back(out.size());
+  second_args.reduce_ndim++;
+  int BN = 32;
+  grid_dims = MTL::Size(256 * ((out.size() + BN - 1) / BN), 1, 1);
+  group_dims = MTL::Size(256, 1, 1);
+
+  // Set the 2nd kernel
+  const std::string second_kernel = "col_reduce_looped_1_32_32_reduce_" +
+      op_name + type_to_name(intermediate);
+  kernel = get_reduce_kernel(
+      d,
+      second_kernel,
+      "col_reduce_looped",
+      op_name,
+      intermediate,
+      out,
+      1,
+      32,
+      32);
+  compute_encoder->setComputePipelineState(kernel);
+
+  compute_encoder.set_input_array(intermediate, 0);
+  compute_encoder.set_output_array(out, 1);
+  second_args.encode(compute_encoder);
   compute_encoder.dispatchThreads(grid_dims, group_dims);
 }
 
@@ -532,9 +622,9 @@ void strided_reduce_looped(
 
   // Figure out the grid dims
   auto out_grid_size = output_grid_for_col_reduce(out, args);
-  int BN = (args.reduction_stride <= 1024) ? 32 : 128;
+  int BN = 32;
   int BM = 1024 / BN;
-  int threadgroup_size = 4 * 32;
+  int threadgroup_size = 8 * 32;
   MTL::Size grid_dims(
       threadgroup_size * ((args.reduction_stride + BN - 1) / BN),
       out_grid_size.width,
@@ -558,6 +648,87 @@ void strided_reduce_looped(
   compute_encoder.dispatchThreads(grid_dims, group_dims);
 }
 
+void strided_reduce_2pass(
+    const array& in,
+    array& out,
+    const std::string& op_name,
+    ColReduceArgs& args,
+    CommandEncoder& compute_encoder,
+    metal::Device& d,
+    const Stream& s) {
+  // Prepare the temporary accumulator
+  std::vector<int> intermediate_shape;
+  intermediate_shape.reserve(out.ndim() + 1);
+  intermediate_shape.push_back(32);
+  intermediate_shape.insert(
+      intermediate_shape.end(), out.shape().begin(), out.shape().end());
+  array intermediate(std::move(intermediate_shape), out.dtype(), nullptr, {});
+  intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
+  d.add_temporary(intermediate, s.index);
+
+  // Prepare the arguments for the kernel
+  args.reduce_shape.push_back(args.reduction_size);
+  args.reduce_strides.push_back(args.reduction_stride);
+  args.reduce_ndim++;
+
+  // Figure out the grid dims
+  size_t out_size = out.size() / args.reduction_stride;
+  auto out_grid_size = output_grid_for_col_reduce(out, args);
+  int outer_blocks = 32;
+  int BN = 32;
+  int BM = 1024 / BN;
+  int threadgroup_size = 8 * 32;
+  MTL::Size grid_dims(
+      threadgroup_size * ((args.reduction_stride + BN - 1) / BN),
+      out_grid_size.width * outer_blocks,
+      out_grid_size.height);
+  MTL::Size group_dims(threadgroup_size, 1, 1);
+
+  // Set the kernel
+  int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
+  std::ostringstream kname;
+  const std::string func_name = "col_reduce_2pass";
+  kname << func_name << "_" << n << "_" << BM << "_" << BN << "_reduce_"
+        << op_name << type_to_name(in);
+  auto kernel =
+      get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n, BM, BN);
+  compute_encoder->setComputePipelineState(kernel);
+
+  // Launch
+  compute_encoder.set_input_array(in, 0);
+  compute_encoder.set_output_array(intermediate, 1);
+  args.encode(compute_encoder);
+  compute_encoder->setBytes(&out_size, sizeof(size_t), 11);
+  compute_encoder.dispatchThreads(grid_dims, group_dims);
+
+  // Make the 2nd pass arguments and grid_dims
+  ColReduceArgs second_args(intermediate);
+  second_args.reduce_shape.push_back(outer_blocks);
+  second_args.reduce_strides.push_back(out.size());
+  second_args.reduce_ndim++;
+  grid_dims = MTL::Size(threadgroup_size * ((out.size() + BN - 1) / BN), 1, 1);
+
+  // Set the 2nd kernel
+  const std::string second_kernel = "col_reduce_looped_1_32_32_reduce_" +
+      op_name + type_to_name(intermediate);
+  kernel = get_reduce_kernel(
+      d,
+      second_kernel,
+      "col_reduce_looped",
+      op_name,
+      intermediate,
+      out,
+      1,
+      32,
+      32);
+  compute_encoder->setComputePipelineState(kernel);
+
+  compute_encoder.set_input_array(intermediate, 0);
+  compute_encoder.set_output_array(out, 1);
+  second_args.encode(compute_encoder);
+  compute_encoder.dispatchThreads(grid_dims, group_dims);
+}
+
 void strided_reduce_general_dispatch(
     const array& in,
     array& out,
@@ -570,11 +741,23 @@ void strided_reduce_general_dispatch(
   // Prepare the arguments for the kernel
   ColReduceArgs args(in, plan, axes);
 
-  if (args.reduction_stride < 32 ||
-      args.reduction_size * args.non_col_reductions < 32) {
+  // Small column
+  if (args.reduction_size * args.non_col_reductions < 32) {
     return strided_reduce_small(in, out, op_name, args, compute_encoder, d, s);
   }
 
+  // Long column but small row
+  if (args.reduction_stride < 32 &&
+      args.reduction_size * args.non_col_reductions >= 1024) {
+    return strided_reduce_longcolumn(
+        in, out, op_name, args, compute_encoder, d, s);
+  }
+
+  if (args.reduction_size * args.non_col_reductions > 256 &&
+      out.size() / 32 < 1024) {
+    return strided_reduce_2pass(in, out, op_name, args, compute_encoder, d, s);
+  }
+
   return strided_reduce_looped(in, out, op_name, args, compute_encoder, d, s);
 }
 
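Note: the dispatch heuristic above, written out as plain Python for readability ("size" stands for reduction_size * non_col_reductions; the thresholds are the ones visible in the hunk):

def choose_col_reduce_kernel(reduction_stride, size, out_size):
    if size < 32:
        return "col_reduce_small"       # tiny reductions: one simple kernel
    if reduction_stride < 32 and size >= 1024:
        return "col_reduce_longcolumn"  # tall, narrow: split rows + 2nd pass
    if size > 256 and out_size // 32 < 1024:
        return "col_reduce_2pass"       # big reduction, small output
    return "col_reduce_looped"          # the general tiled kernel

print(choose_col_reduce_kernel(8, 100000, 8))   # -> col_reduce_longcolumn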
@@ -620,7 +803,6 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
 
   // Reduce
   if (in.size() > 0) {
-    std::vector<array> copies;
     ReductionPlan plan = get_reduction_plan(in, axes_);
 
     // If it is a general reduce then copy the input to a contiguous array and
@@ -632,7 +814,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
     if (plan.type == GeneralReduce) {
       array in_copy(in.shape(), in.dtype(), nullptr, {});
       copy_gpu(in, in_copy, CopyType::General, s);
-      copies.push_back(in_copy);
+      d.add_temporary(in_copy, s.index);
       in = in_copy;
       plan = get_reduction_plan(in, axes_);
     }
@@ -640,7 +822,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
     // Reducing over everything and the data is all there no broadcasting or
     // slicing etc.
     if (plan.type == ContiguousAllReduce) {
-      all_reduce_dispatch(in, out, op_name, compute_encoder, d, s, copies);
+      all_reduce_dispatch(in, out, op_name, compute_encoder, d, s);
     }
 
     // At least the last dimension is row contiguous and we are reducing over
@@ -659,8 +841,6 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
       strided_reduce_general_dispatch(
           in, out, op_name, plan, axes_, compute_encoder, d, s);
     }
-
-    d.add_temporaries(std::move(copies), s.index);
   }
 
   // Nothing to reduce just initialize the output
@@ -16,8 +16,7 @@ void all_reduce_dispatch(
     const std::string& op_name,
     CommandEncoder& compute_encoder,
     metal::Device& d,
-    const Stream& s,
-    std::vector<array>& copies);
+    const Stream& s);
 
 void row_reduce_general_dispatch(
     const array& in,