Refactor the reduction kernels (#277)

Authored by Angelos Katharopoulos on 2023-12-24 14:47:57 -08:00, committed by GitHub
parent 22fee5a383
commit 9e6b8c9f48
4 changed files with 179 additions and 369 deletions

View File

@ -125,6 +125,14 @@ if __name__ == "__main__":
compare_filtered("sum_axis --size 16x128x1024 --axis 1")
compare_filtered("sum_axis --size 16x128x1024 --axis 0 --cpu")
compare_filtered("sum_axis --size 16x128x1024 --axis 0")
compare_filtered("sum_axis --size 16x128x1024 --axis 0,1 --cpu")
compare_filtered("sum_axis --size 16x128x1024 --axis 0,1")
compare_filtered("sum_axis --size 16x128x1024 --axis 0,2 --cpu")
compare_filtered("sum_axis --size 16x128x1024 --axis 0,2")
compare_filtered("sum_axis --size 16x128x1024 --axis 0,1 --transpose 0,2,1 --cpu")
compare_filtered("sum_axis --size 16x128x1024 --axis 0,1 --transpose 0,2,1")
compare_filtered("sum_axis --size 16x128x1024 --axis 0,2 --transpose 0,2,1 --cpu")
compare_filtered("sum_axis --size 16x128x1024 --axis 0,2 --transpose 0,2,1")
compare_filtered("argmax --size 10x1024x128 --axis 1 --cpu")
compare_filtered("argmax --size 10x1024x128 --axis 1")
compare_filtered("argmax --size 10x1024x128 --axis 2 --cpu")

View File

@ -126,7 +126,7 @@ struct ReductionPlan {
ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
// The data is all there and we are reducing over everything
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
(x.flags().row_contiguous || x.flags().col_contiguous)) {
x.flags().contiguous) {
return ContiguousAllReduce;
}
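
For intuition, a standalone sketch of the fast path this check gates, assuming "contiguous" means the buffer has no gaps and size() == data_size() means no buffer element is reused (as it would be in a broadcast view). The toy types are illustrative stand-ins, not MLX's array:

#include <cstddef>
#include <numeric>
#include <vector>

struct ToyArray {
  std::vector<float> buffer;  // underlying allocation, data_size() elements
  size_t logical_size;        // size(); exceeds the buffer for broadcast views
  bool contiguous;            // row- or col-contiguous, no gaps
};

// Mirrors the condition above: every buffer element participates exactly once
// and we reduce over every axis, so element order is irrelevant.
bool can_contiguous_all_reduce(const ToyArray& x, size_t n_reduce_axes, size_t ndim) {
  return x.logical_size == x.buffer.size() && n_reduce_axes == ndim && x.contiguous;
}

// The payoff: the whole reduction is one linear pass over memory.
float contiguous_all_reduce_sum(const ToyArray& x) {
  return std::accumulate(x.buffer.begin(), x.buffer.end(), 0.0f);
}
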

View File

@ -112,80 +112,22 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
///////////////////////////////////////////////////////////////////////////////
// General reduce
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U, typename Op>
[[kernel]] void general_reduce(
const device T *in [[buffer(0)]],
device mlx_atomic<U> *out [[buffer(1)]],
const device int *in_shape [[buffer(2)]],
const device size_t *in_strides [[buffer(3)]],
const device size_t *out_strides [[buffer(4)]],
const device size_t& ndim [[buffer(5)]],
uint gid [[thread_position_in_grid]]) {
Op op;
auto in_idx = elem_to_loc(gid, in_shape, in_strides, ndim);
auto out_idx = elem_to_loc(gid, in_shape, out_strides, ndim);
op.atomic_update(out, static_cast<U>(in[in_idx]), out_idx);
}
template <typename T, typename U, typename Op, int NDIM>
[[kernel]] void general_reduce(
const device T *in [[buffer(0)]],
device mlx_atomic<U> *out [[buffer(1)]],
const device int *in_shape [[buffer(2)]],
const device size_t *in_strides [[buffer(3)]],
const device size_t *out_strides [[buffer(4)]],
uint gid [[thread_position_in_grid]]) {
Op op;
auto in_idx = elem_to_loc_nd<NDIM>(gid, in_shape, in_strides);
auto out_idx = elem_to_loc_nd<NDIM>(gid, in_shape, out_strides);
op.atomic_update(out, static_cast<U>(in[in_idx]), out_idx);
}
#define instantiate_general_reduce_helper(name, itype, otype, op) \
template [[host_name("general_reduce_" #name)]] \
[[kernel]] void general_reduce<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device mlx_atomic<otype> *out [[buffer(1)]], \
const device int *in_shape [[buffer(2)]], \
const device size_t *in_strides [[buffer(3)]], \
const device size_t *out_strides [[buffer(4)]], \
const device size_t& ndim [[buffer(5)]], \
uint gid [[thread_position_in_grid]]);
#define instantiate_general_reduce_helper_nd(name, itype, otype, op, n) \
template [[host_name("general_reduce_" #name "_dim_" #n)]] \
[[kernel]] void general_reduce<itype, otype, op, n>( \
const device itype *in [[buffer(0)]], \
device mlx_atomic<otype> *out [[buffer(1)]], \
const device int *in_shape [[buffer(2)]], \
const device size_t *in_strides [[buffer(3)]], \
const device size_t *out_strides [[buffer(4)]], \
uint gid [[thread_position_in_grid]]);
#define instantiate_general_reduce(name, itype, otype, op) \
instantiate_general_reduce_helper(name, itype, otype, op) \
instantiate_general_reduce_helper_nd(name, itype, otype, op, 1) \
instantiate_general_reduce_helper_nd(name, itype, otype, op, 2) \
instantiate_general_reduce_helper_nd(name, itype, otype, op, 3) \
instantiate_general_reduce_helper_nd(name, itype, otype, op, 4)
///////////////////////////////////////////////////////////////////////////////
// Row atomics
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void row_reduce(
[[kernel]] void row_reduce_general(
const device T *in [[buffer(0)]],
device U *out [[buffer(1)]],
const device size_t& reduction_size [[buffer(2)]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint tid [[threadgroup_position_in_grid]],
device mlx_atomic<U> *out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant int* shape [[buffer(4)]],
const constant size_t* strides [[buffer(5)]],
const constant int& ndim [[buffer(6)]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
@ -193,7 +135,10 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
Op op;
// Each threadgroup handles 1 reduction
in += tid * reduction_size + lid * N_READS;
// TODO: Specializing elem_to_loc would be slightly faster
int idx = tid.y * out_size + tid.x;
int extra_offset = elem_to_loc(idx, shape, strides, ndim);
in += extra_offset + lid.x * N_READS;
// The reduction is accumulated here
U total_val = Op::init;
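
The added indexing relies on elem_to_loc to turn the linear index tid.y * out_size + tid.x into an input offset. A plain C++ reference for what that mapping is assumed to compute (peel row-major coordinates off the linear index and re-weight them by the strides); illustrative, not the Metal helper itself:

#include <cstddef>

size_t elem_to_loc_ref(size_t elem, const int* shape, const size_t* strides, int ndim) {
  size_t loc = 0;
  // Fastest-moving dimension first.
  for (int i = ndim - 1; i >= 0; --i) {
    loc += (elem % static_cast<size_t>(shape[i])) * strides[i];
    elem /= static_cast<size_t>(shape[i]);
  }
  return loc;
}
// Example: a 3x2 row-major buffer viewed transposed has shape {2, 3} and
// strides {1, 2}; logical element 4 (row 1, col 1) lands at offset 1*1 + 1*2 = 3.
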
@ -201,7 +146,7 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
// Loop over the reduction size within thread group
int r = 0;
for (; r < (int)ceildiv(reduction_size, N_READS*lsize) - 1; r++) {
for (; r < (int)ceildiv(reduction_size, N_READS*lsize.x) - 1; r++) {
T vals[N_READS];
for(int i = 0; i < N_READS; i++) {
vals[i] = in[i];
@ -210,11 +155,11 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
total_val = op(static_cast<U>(vals[i]), total_val);
}
in += lsize * N_READS;
in += lsize.x * N_READS;
}
// Sepate case for the last set as we close the reduction size
size_t reduction_index = (lid + (size_t)lsize * r) * N_READS;
// Separate case for the last set as we close the reduction size
size_t reduction_index = (lid.x + (size_t)lsize.x * r) * N_READS;
if(reduction_index < reduction_size) {
int max_reads = reduction_size - reduction_index;
@ -240,26 +185,30 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
// Reduction within thread group
// Only needed if multiple simd groups
if(reduction_size > simd_size) {
total_val = lid < simd_per_group ? local_vals[lid] : op.init;
total_val = lid.x < simd_per_group ? local_vals[lid.x] : op.init;
total_val = op.simd_reduce(total_val);
}
// Update output
if (lid == 0) {
out[tid] = total_val;
if (lid.x == 0) {
op.atomic_update(out, total_val, tid.x);
}
}
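
For reference, a CPU transcription of the loop structure in row_reduce_general: each of lsize cooperating threads consumes N_READS contiguous values per step, advances by lsize * N_READS, and clamps the final step to the row length. The per-thread partials are then combined by the simdgroup/threadgroup reduction and atomic update shown above, which this sketch elides:

#include <algorithm>
#include <cstddef>

// Partial sum produced by thread `lid` of a group of `lsize` threads reducing
// one row of `reduction_size` contiguous floats starting at `row`.
template <int N_READS = 4>
float row_reduce_partial(const float* row, size_t reduction_size,
                         size_t lsize, size_t lid) {
  const float* ptr = row + lid * N_READS;
  float total = 0.0f;
  size_t steps =
      (reduction_size + N_READS * lsize - 1) / (N_READS * lsize);  // ceildiv
  size_t r = 0;
  for (; r + 1 < steps; ++r) {             // all but the last step: full reads
    for (int i = 0; i < N_READS; ++i) total += ptr[i];
    ptr += lsize * N_READS;
  }
  // Last step: only read what is left of the row (the kernel pads with Op::init).
  size_t reduction_index = (lid + lsize * r) * N_READS;
  if (reduction_index < reduction_size) {
    size_t max_reads = reduction_size - reduction_index;
    for (size_t i = 0; i < std::min<size_t>(N_READS, max_reads); ++i) {
      total += ptr[i];
    }
  }
  return total;
}
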
#define instantiate_row_reduce(name, itype, otype, op) \
template [[host_name("row_reduce_" #name)]] \
[[kernel]] void row_reduce<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device otype *out [[buffer(1)]], \
const device size_t& reduction_size [[buffer(2)]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
#define instantiate_row_reduce_general(name, itype, otype, op) \
template [[host_name("row_reduce_general_" #name)]] \
[[kernel]] void row_reduce_general<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device mlx_atomic<otype> *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant int* shape [[buffer(4)]], \
const constant size_t* strides [[buffer(5)]], \
const constant int& ndim [[buffer(6)]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
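
The [[host_name(...)]] string pinned by this macro is the key the host dispatcher looks up ("row_reduce_general_" + op_name + type_to_name(in) in reduce.cpp below), so the macro's name argument has to be the concatenated op/type pair. A small host-side sketch of that string assembly; the exact type suffix ("float32") is an assumption here, not taken from the source:

#include <iostream>
#include <string>

// Hypothetical helper mirroring how the lookup string is built; only the
// pattern matters, the concrete suffixes are illustrative.
std::string reduce_kernel_name(const std::string& variant,
                               const std::string& op_name,
                               const std::string& type_name) {
  return variant + "_" + op_name + type_name;
}

int main() {
  // Must match a [[host_name("row_reduce_general_" #name)]] instantiation,
  // i.e. a macro invoked with name == op_name followed by the type suffix.
  std::cout << reduce_kernel_name("row_reduce_general", "sum", "float32") << "\n";
  // -> row_reduce_general_sumfloat32
}
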
@ -311,62 +260,26 @@ inline void _contiguous_strided_reduce(
}
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void col_reduce(
[[kernel]] void col_reduce_general(
const device T *in [[buffer(0)]],
device mlx_atomic<U> *out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
threadgroup U *local_data [[threadgroup(0)]],
uint2 tid [[threadgroup_position_in_grid]],
uint2 lid [[thread_position_in_threadgroup]],
uint2 lsize [[threads_per_threadgroup]]) {
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]]) {
auto out_idx = tid.x * lsize.x + lid.x;
if(out_idx < out_size) {
_contiguous_strided_reduce<T, U, Op, N_READS>(
in,
out,
local_data,
out_idx,
out_idx,
reduction_size,
reduction_stride,
tid,
lid,
lsize);
}
}
#define instantiate_col_reduce(name, itype, otype, op) \
template [[host_name("col_reduce_" #name)]] \
[[kernel]] void col_reduce<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device mlx_atomic<otype> *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
threadgroup otype *local_data [[threadgroup(0)]], \
uint2 tid [[threadgroup_position_in_grid]], \
uint2 lid [[thread_position_in_threadgroup]], \
uint2 lsize [[threads_per_threadgroup]]);
template <typename T, typename U, typename Op, int NDIM, int N_READS = 16>
[[kernel]] void contiguous_strided_reduce(
const device T *in [[buffer(0)]],
device mlx_atomic<U> *out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const device int* in_shape [[buffer(5)]],
const device size_t* in_strides [[buffer(6)]],
threadgroup U *local_data [[threadgroup(0)]],
uint2 tid [[threadgroup_position_in_grid]],
uint2 lid [[thread_position_in_threadgroup]],
uint2 lsize [[threads_per_threadgroup]]) {
auto out_idx = tid.x * lsize.x + lid.x;
auto in_idx = elem_to_loc_nd<NDIM>(out_idx, in_shape, in_strides);
auto in_idx = elem_to_loc(
out_idx + tid.z * out_size,
shape,
strides,
ndim
);
if(out_idx < out_size) {
_contiguous_strided_reduce<T, U, Op, N_READS>(
@ -377,82 +290,27 @@ template <typename T, typename U, typename Op, int NDIM, int N_READS = 16>
out_idx,
reduction_size,
reduction_stride,
tid,
lid,
lsize);
tid.xy,
lid.xy,
lsize.xy);
}
}
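
The essential difference from the row kernel is the access pattern: each output gathers reduction_size inputs that sit reduction_stride apart in memory, starting from the offset elem_to_loc resolves for that output (and, via the new tid.z grid dimension, for that block of non-column reductions). A minimal CPU sketch of the pattern, with the threadgroup tiling and shared-memory combine in _contiguous_strided_reduce left out:

#include <cstddef>

float col_reduce_ref(const float* in, size_t base_offset,
                     size_t reduction_size, size_t reduction_stride) {
  float total = 0.0f;
  for (size_t k = 0; k < reduction_size; ++k) {
    total += in[base_offset + k * reduction_stride];  // strided gather
  }
  return total;
}
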
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void contiguous_strided_reduce(
const device T *in [[buffer(0)]],
device mlx_atomic<U> *out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const device int* in_shape [[buffer(5)]],
const device size_t* in_strides [[buffer(6)]],
const device size_t& in_dim [[buffer(7)]],
threadgroup U *local_data [[threadgroup(0)]],
uint2 tid [[threadgroup_position_in_grid]],
uint2 lid [[thread_position_in_threadgroup]],
uint2 lsize [[threads_per_threadgroup]]) {
auto out_idx = tid.x * lsize.x + lid.x;
auto in_idx = elem_to_loc(out_idx, in_shape, in_strides, in_dim);
if(out_idx < out_size) {
_contiguous_strided_reduce<T, U, Op, N_READS>(
in,
out,
local_data,
in_idx,
out_idx,
reduction_size,
reduction_stride,
tid,
lid,
lsize);
}
}
#define instantiate_contiguous_strided_helper(name, itype, otype, op) \
template [[host_name("contiguous_strided_reduce_" #name)]] \
[[kernel]] void contiguous_strided_reduce<itype, otype, op>( \
#define instantiate_col_reduce_general(name, itype, otype, op) \
template [[host_name("col_reduce_general_" #name)]] \
[[kernel]] void col_reduce_general<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device mlx_atomic<otype> *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const device int* in_shape [[buffer(5)]], \
const device size_t* in_strides [[buffer(6)]], \
const device size_t& in_dim [[buffer(7)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
threadgroup otype *local_data [[threadgroup(0)]], \
uint2 tid [[threadgroup_position_in_grid]], \
uint2 lid [[thread_position_in_threadgroup]], \
uint2 lsize [[threads_per_threadgroup]]);
#define instantiate_contiguous_strided_helper_nd(name, itype, otype, op, n) \
template [[host_name("contiguous_strided_reduce_" #name "_dim_" #n)]] \
[[kernel]] void contiguous_strided_reduce<itype, otype, op, n>( \
const device itype *in [[buffer(0)]], \
device mlx_atomic<otype> *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const device int* in_shape [[buffer(5)]], \
const device size_t* in_strides [[buffer(6)]], \
threadgroup otype *local_data [[threadgroup(0)]], \
uint2 tid [[threadgroup_position_in_grid]], \
uint2 lid [[thread_position_in_threadgroup]], \
uint2 lsize [[threads_per_threadgroup]]);
#define instantiate_contiguous_strided(name, itype, otype, op) \
instantiate_contiguous_strided_helper(name, itype, otype, op) \
instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 1) \
instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 2) \
instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 3) \
instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 4)
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]]);
///////////////////////////////////////////////////////////////////////////////
@ -461,10 +319,8 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
#define instantiate_reduce(name, itype, otype, op) \
instantiate_all_reduce(name, itype, otype, op) \
instantiate_row_reduce(name, itype, otype, op) \
instantiate_col_reduce(name, itype, otype, op) \
instantiate_contiguous_strided(name, itype, otype, op) \
instantiate_general_reduce(name, itype, otype, op)
instantiate_row_reduce_general(name, itype, otype, op) \
instantiate_col_reduce_general(name, itype, otype, op)
#define instantiate_same_reduce(name, tname, type, op) \
instantiate_init_reduce(name ##tname, type, op<type>) \

View File

@ -2,9 +2,11 @@
#include <algorithm>
#include <cassert>
#include <iostream>
#include <sstream>
#include "mlx/backend/common/reduce.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/utils.h"
@ -61,22 +63,47 @@ void all_reduce_dispatch(
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
void row_reduce_dispatch(
void row_reduce_general_dispatch(
const array& in,
array& out,
const std::string& op_name,
const std::vector<int>& axes_,
const ReductionPlan& plan,
const std::vector<int>& axes,
MTL::ComputeCommandEncoder* compute_encoder,
metal::Device& d) {
auto kernel = d.get_kernel("row_reduce_" + op_name + type_to_name(in));
auto kernel =
d.get_kernel("row_reduce_general_" + op_name + type_to_name(in));
// Prepare the arguments for the kernel
int n_reads = REDUCE_N_READS;
size_t reduction_size = in.size() / out.size();
size_t reduction_size = plan.shape.back();
size_t out_size = out.size();
auto shape = plan.shape;
auto strides = plan.strides;
shape.pop_back();
strides.pop_back();
size_t non_row_reductions = 1;
for (auto s : shape) {
non_row_reductions *= static_cast<size_t>(s);
}
auto [rem_shape, rem_strides] = shapes_without_reduction_axes(in, axes);
for (auto s : rem_shape) {
shape.push_back(s);
}
for (auto s : rem_strides) {
strides.push_back(s);
}
int ndim = shape.size();
// Set the arguments for the kernel
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4);
compute_encoder->setBytes(strides.data(), strides.size() * sizeof(size_t), 5);
compute_encoder->setBytes(&ndim, sizeof(int), 6);
// Each thread group is responsible for 1 output
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
@ -91,92 +118,54 @@ void row_reduce_dispatch(
// Launch enough thread groups for each output
size_t n_threads = out.size() * thread_group_size;
MTL::Size grid_dims = MTL::Size(n_threads, 1, 1);
MTL::Size grid_dims = MTL::Size(n_threads, non_row_reductions, 1);
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
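
To make the launch geometry concrete, a worked example with illustrative numbers (loosely modelled on a transposed sum_axis benchmark case; the figures, including the clamped threadgroup size, are assumptions, not values from a real run):

#include <cstddef>
#include <cstdio>

int main() {
  // Assumed plan: a contiguous row of 1024 to reduce, repeated over 16
  // non-row reduction blocks, with 128 outputs.
  size_t reduction_size = 1024;
  size_t non_row_reductions = 16;
  size_t out_size = 128;
  size_t thread_group_size = 1024;  // after clamping against the row length

  // dispatchThreads takes total threads, so x covers out_size threadgroups of
  // thread_group_size threads and y selects the non-row reduction block.
  size_t grid_x = out_size * thread_group_size;  // 131072 threads
  size_t grid_y = non_row_reductions;            // 16

  std::printf("reduce %zu elements/row, grid=(%zu,%zu,1), group=(%zu,1,1)\n",
              reduction_size, grid_x, grid_y, thread_group_size);
  // In the kernel, tid.y * out_size + tid.x identifies which (block, output)
  // pair a threadgroup owns before elem_to_loc turns it into an input offset.
}
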
void col_reduce_dispatch(
void strided_reduce_general_dispatch(
const array& in,
array& out,
const std::string& op_name,
const std::vector<int>& axes_,
const ReductionPlan& plan,
const std::vector<int>& axes,
MTL::ComputeCommandEncoder* compute_encoder,
metal::Device& d) {
std::ostringstream kernel_name;
auto kernel =
d.get_kernel("col_reduce_general_" + op_name + type_to_name(in));
bool encode_in_shape = false;
bool encode_ndim = false;
// If the slowest moving axis can be merged into the reductions,
// we call the column reduce kernel
// In this case, a linear index in the output corresponds to the
// linear index in the input where the reduction starts
if (axes_[axes_.size() - 1] == (axes_.size() - 1)) {
kernel_name << "col_reduce_" << op_name << type_to_name(in);
}
// Otherwise, while all the reduction axes can be merged, the mapping between
// indices in the output and input require resolving using shapes and strides
else {
kernel_name << "contiguous_strided_reduce_" << op_name << type_to_name(in);
encode_in_shape = true;
// We check for a viable template with the required number of dimensions
// we only care about encoding non-reduced shapes and strides in the input
size_t non_reducing_dims = in.ndim() - axes_.size();
if (non_reducing_dims >= 1 &&
non_reducing_dims <= MAX_REDUCE_SPECIALIZED_DIMS) {
kernel_name << "_dim_" << non_reducing_dims;
} else {
encode_ndim = true;
}
}
auto kernel = d.get_kernel(kernel_name.str());
size_t in_size = in.size();
// Prepare the arguments for the kernel
size_t reduction_size = plan.shape.back();
size_t reduction_stride = plan.strides.back();
size_t out_size = out.size();
auto shape = plan.shape;
auto strides = plan.strides;
shape.pop_back();
strides.pop_back();
size_t non_col_reductions = 1;
for (auto s : shape) {
non_col_reductions *= static_cast<size_t>(s);
}
auto [rem_shape, rem_strides] = shapes_without_reduction_axes(in, axes);
for (auto s : rem_shape) {
shape.push_back(s);
}
for (auto s : rem_strides) {
strides.push_back(s);
}
int ndim = shape.size();
// Set the arguments for the kernel
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
// Calculate the number of inputs to reduce and the stride b/w them
size_t reduction_size = 1;
size_t in_ndim = in.ndim();
size_t reduction_stride = in_size;
for (int i : axes_) {
reduction_size *= in.shape(i);
reduction_stride = std::min(reduction_stride, in.strides()[i]);
}
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
compute_encoder->setBytes(&out_size, sizeof(size_t), 4);
if (encode_in_shape) {
// Obtain the non-reducing shape and strides of the input to encode
std::vector<int> inp_shape_mod;
std::vector<size_t> inp_strides_mod;
for (size_t i = 0, j = 0; i < in.ndim(); i++) {
if (j < axes_.size() && axes_[j] == i) {
j++;
} else {
inp_shape_mod.push_back(in.shape(i));
inp_strides_mod.push_back(in.strides()[i]);
}
}
size_t ndim = inp_shape_mod.size();
compute_encoder->setBytes(inp_shape_mod.data(), ndim * sizeof(int), 5);
compute_encoder->setBytes(inp_strides_mod.data(), ndim * sizeof(size_t), 6);
if (encode_ndim) {
compute_encoder->setBytes(&ndim, sizeof(size_t), 7);
}
}
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5);
compute_encoder->setBytes(strides.data(), strides.size() * sizeof(size_t), 6);
compute_encoder->setBytes(&ndim, sizeof(int), 7);
// Select block dimensions
@ -200,7 +189,8 @@ void col_reduce_dispatch(
(n_threads_per_output + threadgroup_dim_y - 1) / threadgroup_dim_y;
// Launch enough thread groups for each output
MTL::Size grid_dims = MTL::Size(n_threadgroups_x, n_threadgroups_y, 1);
MTL::Size grid_dims =
MTL::Size(n_threadgroups_x, n_threadgroups_y, non_col_reductions);
MTL::Size group_dims = MTL::Size(threadgroup_dim_x, threadgroup_dim_y, 1);
// We set shared memory to be exploited here for reductions within a
@ -216,60 +206,6 @@ void col_reduce_dispatch(
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
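
Both dispatchers do the same bookkeeping on the plan: pop the innermost reduction entry off as reduction_size/stride, fold what remains into the extra grid dimension, and append the non-reduced dims so the kernel's elem_to_loc can find each output's starting offset. A standalone sketch with made-up but internally consistent numbers, assuming we reduce axes 0 and 2 of a row-contiguous 16x4x128x1024 array (an illustration, not one of the benchmark cases):

#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  // Reduction plan for axes 0 and 2 of the assumed 16x4x128x1024 layout.
  std::vector<int> shape = {16, 128};
  std::vector<size_t> strides = {524288, 1024};

  // Innermost entry -> the kernel's reduction_size / reduction_stride.
  size_t reduction_size = shape.back();      // 128
  size_t reduction_stride = strides.back();  // 1024
  shape.pop_back();
  strides.pop_back();

  // What is left counts the extra reduction blocks handled via tid.z.
  size_t non_col_reductions = 1;
  for (int s : shape) non_col_reductions *= static_cast<size_t>(s);  // 16

  // Non-reduced dims (what shapes_without_reduction_axes is assumed to yield
  // here): axes 1 and 3, appended so elem_to_loc can locate each output.
  std::vector<int> rem_shape = {4, 1024};
  std::vector<size_t> rem_strides = {131072, 1};
  shape.insert(shape.end(), rem_shape.begin(), rem_shape.end());
  strides.insert(strides.end(), rem_strides.begin(), rem_strides.end());

  std::printf("reduction_size=%zu stride=%zu blocks=%zu ndim=%zu out_size=%d\n",
              reduction_size, reduction_stride, non_col_reductions,
              shape.size(), 4 * 1024);
}
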
void general_reduce_dispatch(
const array& in,
array& out,
const std::string& op_name,
const std::vector<int>& axes_,
MTL::ComputeCommandEncoder* compute_encoder,
metal::Device& d) {
bool encode_ndim = true;
std::ostringstream kernel_name;
kernel_name << "general_reduce_" << op_name << type_to_name(in);
// Check for specialzed kernels for input ndim
if (in.ndim() >= 1 && in.ndim() <= MAX_REDUCE_SPECIALIZED_DIMS) {
kernel_name << "_dim_" << in.ndim();
encode_ndim = false;
}
auto kernel = d.get_kernel(kernel_name.str());
size_t in_size = in.size();
size_t ndim = in.ndim();
// We set the reducing strides to 0 to induce collisions for the reduction
std::vector<size_t> out_strides(ndim);
size_t stride = 1;
for (int i = ndim - 1, j = axes_.size() - 1; i >= 0; --i) {
if (j >= 0 && axes_[j] == i) {
out_strides[i] = 0;
--j;
} else {
out_strides[i] = stride;
stride *= in.shape(i);
}
}
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(in.shape().data(), ndim * sizeof(int), 2);
compute_encoder->setBytes(in.strides().data(), ndim * sizeof(size_t), 3);
compute_encoder->setBytes(out_strides.data(), ndim * sizeof(size_t), 4);
if (encode_ndim) {
compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
}
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > in_size) {
thread_group_size = in_size;
}
size_t nthreads = in_size;
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
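
The deleted general_reduce path worked by the trick its comment describes: the output strides are rebuilt with 0 on every reduced axis, so distinct input elements resolve to the same output location and have to be combined with atomic updates. A plain C++ illustration of that collision for reducing axis 1 of a 2x3 array:

#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> shape = {2, 3};           // reduce over axis 1
  std::vector<size_t> out_strides = {1, 0};  // stride 0 on the reduced axis
  std::vector<float> in = {1, 2, 3, 4, 5, 6};
  std::vector<float> out = {0, 0};

  for (size_t gid = 0; gid < in.size(); ++gid) {
    // elem_to_loc over (shape, out_strides): the reduced axis contributes
    // nothing, so all 3 elements of a row collide on the same output index.
    size_t elem = gid, out_idx = 0;
    for (int i = (int)shape.size() - 1; i >= 0; --i) {
      out_idx += (elem % shape[i]) * out_strides[i];
      elem /= shape[i];
    }
    out[out_idx] += in[gid];  // the GPU kernel did this with op.atomic_update
  }
  std::printf("out = {%g, %g}\n", out[0], out[1]);  // {6, 15}
}
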
} // namespace
//////////////////////////////////////////////////////////////////////
@ -278,7 +214,7 @@ void general_reduce_dispatch(
void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
array in = inputs[0];
// TODO: Allow specific row and column reductions with types disabled
// due to atomics ?
@ -335,36 +271,46 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
// Reduce
{
// Check for contiguous data
if (in.size() == in.data_size() &&
(in.flags().row_contiguous || in.flags().col_contiguous)) {
// Go to all reduce if reducing over all axes
if (axes_.size() == in.ndim()) {
all_reduce_dispatch(in, out, op_name, compute_encoder, d);
return;
}
// Use specialized kernels if the input is row contiguous and
// the reducing axes can be merged into one
else if (
in.flags().row_contiguous && in.strides().back() == 1 &&
(axes_.back() - axes_.front()) == axes_.size() - 1) {
// If the fastest moving axis is being reduced, go to row reduce
if (axes_[0] == (in.ndim() - axes_.size())) {
row_reduce_dispatch(in, out, op_name, axes_, compute_encoder, d);
return;
}
// Otherwise go to to generalized strided reduce
// Note: bool isn't support here yet due to the use of atomics
// once that is updated, this should be the else condition of this
// branch
else if (in.dtype() != bool_) {
col_reduce_dispatch(in, out, op_name, axes_, compute_encoder, d);
return;
}
}
std::vector<array> copies;
ReductionPlan plan = get_reduction_plan(in, axes_);
// If it is a general reduce then copy the input to a contiguous array and
// recompute the plan.
if (plan.type == GeneralReduce) {
array in_copy(in.shape(), in.dtype(), nullptr, {});
copy_gpu(in, in_copy, CopyType::General, s);
copies.push_back(in_copy);
in = in_copy;
plan = get_reduction_plan(in, axes_);
}
// Reducing over everything and the data is all there, no broadcasting or
// slicing etc.
if (plan.type == ContiguousAllReduce) {
all_reduce_dispatch(in, out, op_name, compute_encoder, d);
}
// At least the last dimension is row contiguous and we are reducing over
// the last dim.
else if (
plan.type == ContiguousReduce || plan.type == GeneralContiguousReduce) {
row_reduce_general_dispatch(
in, out, op_name, plan, axes_, compute_encoder, d);
}
// At least the last two dimensions are contiguous and we are doing a
// strided reduce over these.
else if (
plan.type == ContiguousStridedReduce ||
plan.type == GeneralStridedReduce) {
strided_reduce_general_dispatch(
in, out, op_name, plan, axes_, compute_encoder, d);
}
if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
// Fall back to the general case
general_reduce_dispatch(in, out, op_name, axes_, compute_encoder, d);
}
}