Add GPU support for uint64/int64 reductions (#569)

This commit is contained in:
Vijay Krish 2024-01-31 11:18:04 -08:00 committed by GitHub
parent bad67fec37
commit fcc5ac1c64
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 654 additions and 192 deletions

View File

@ -63,18 +63,6 @@ struct ArgMax {
}
};
bool simd_shuffle_down(bool data, uint16_t delta) {
return simd_shuffle_down(static_cast<uint32_t>(data), delta);
}
uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) {
return as_type<uint64_t>(simd_shuffle_down(as_type<uint2>(data), delta));
}
int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
return as_type<int64_t>(simd_shuffle_down(as_type<uint2>(data), delta));
}
template <typename U>
IndexValPair<U> simd_shuffle_down(IndexValPair<U> data, uint16_t delta) {
return IndexValPair<U>(

View File

@ -24,31 +24,20 @@ template <typename T, typename Op>
device otype *out [[buffer(1)]], \
uint tid [[thread_position_in_grid]]);
///////////////////////////////////////////////////////////////////////////////
// All reduce
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void all_reduce(
const device T *in [[buffer(0)]],
device mlx_atomic<U> *out [[buffer(1)]],
const device size_t& in_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint grid_size [[threads_per_grid]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
// NB: this kernel assumes threads_per_threadgroup is at most
// 1024. This way with a simd_size of 32, we are guaranteed to
// complete the reduction in two steps of simd-level reductions.
inline U per_thread_all_reduce(
const device T *in,
const device size_t& in_size,
uint gid,
uint grid_size) {
Op op;
threadgroup U local_vals[simd_size];
U total_val = Op::init;
if (gid * N_READS < in_size) {
in += gid * N_READS;
int r = 0;
@ -80,6 +69,30 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
total_val = op(static_cast<U>(val), total_val);
}
}
}
return total_val;
}
// NB: This kernel assumes threads_per_threadgroup is at most
// 1024. This way with a simd_size of 32, we are guaranteed to
// complete the reduction in two steps of simd-level reductions.
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void all_reduce(
const device T *in [[buffer(0)]],
device mlx_atomic<U> *out [[buffer(1)]],
const device size_t& in_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint grid_size [[threads_per_grid]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
Op op;
threadgroup U local_vals[simd_size];
U total_val = per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);
// Reduction within simd group
total_val = op.simd_reduce(total_val);
@ -98,6 +111,46 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
}
}
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void all_reduce_no_atomics(
const device T *in [[buffer(0)]],
device U *out [[buffer(1)]],
const device size_t& in_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint grid_size [[threads_per_grid]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint thread_group_id [[threadgroup_position_in_grid]]) {
Op op;
threadgroup U local_vals[simd_size];
U total_val = per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);
// Reduction within simd group (simd_add isn't supported for uint64/int64 types)
for (uint16_t lane_offset = simd_size/2; lane_offset > 0; lane_offset /= 2) {
total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
}
// Write simd group reduction results to local memory
if (simd_lane_id == 0) {
local_vals[simd_group_id] = total_val;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Reduction of simdgroup reduction results within threadgroup.
total_val = lid < simd_per_group ? local_vals[lid] : op.init;
for (uint16_t lane_offset = simd_size/2; lane_offset > 0; lane_offset /= 2) {
total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
}
// Reduction across threadgroups
if (lid == 0) {
out[thread_group_id] = total_val;
}
}
#define instantiate_all_reduce(name, itype, otype, op) \
template [[host_name("all_reduce_" #name)]] \
[[kernel]] void all_reduce<itype, otype, op>( \
@ -111,11 +164,80 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \
template [[host_name("all_reduce_no_atomics_" #name)]] \
[[kernel]] void all_reduce_no_atomics<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device otype *out [[buffer(1)]], \
const device size_t& in_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint grid_size [[threads_per_grid]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint thread_group_id [[threadgroup_position_in_grid]]);
///////////////////////////////////////////////////////////////////////////////
// Row atomics
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
inline U per_thread_row_reduce(
const device T *in,
const constant size_t& reduction_size,
const constant size_t& out_size,
const constant int* shape,
const constant size_t* strides,
const constant int& ndim,
uint lsize_x,
uint lid_x,
uint2 tid) {
Op op;
// Each threadgroup handles 1 reduction
// TODO: Specializing elem_to_loc would be slightly faster
int idx = tid.y * out_size + tid.x;
int extra_offset = elem_to_loc(idx, shape, strides, ndim);
in += extra_offset + lid_x * N_READS;
// The reduction is accumulated here
U total_val = Op::init;
// Loop over the reduction size within thread group
int r = 0;
for (; r < (int)ceildiv(reduction_size, N_READS*lsize_x) - 1; r++) {
T vals[N_READS];
for(int i = 0; i < N_READS; i++) {
vals[i] = in[i];
}
for(int i = 0; i < N_READS; i++) {
total_val = op(static_cast<U>(vals[i]), total_val);
}
in += lsize_x * N_READS;
}
// Separate case for the last set as we close the reduction size
size_t reduction_index = (lid_x + (size_t)lsize_x * r) * N_READS;
if(reduction_index < reduction_size) {
int max_reads = reduction_size - reduction_index;
T vals[N_READS];
for(int i = 0; i < N_READS; i++) {
int idx = min(i, max_reads - 1);
vals[i] = static_cast<U>(in[idx]);
}
for(int i = 0; i < N_READS; i++) {
T val = i < max_reads ? vals[i] : Op::init;
total_val = op(static_cast<U>(val), total_val);
}
}
return total_val;
}
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void row_reduce_general(
const device T *in [[buffer(0)]],
@ -133,46 +255,9 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
Op op;
// Each threadgroup handles 1 reduction
// TODO: Specializing elem_to_loc would be slightly faster
int idx = tid.y * out_size + tid.x;
int extra_offset = elem_to_loc(idx, shape, strides, ndim);
in += extra_offset + lid.x * N_READS;
// The reduction is accumulated here
U total_val = Op::init;
threadgroup U local_vals[simd_size];
// Loop over the reduction size within thread group
int r = 0;
for (; r < (int)ceildiv(reduction_size, N_READS*lsize.x) - 1; r++) {
T vals[N_READS];
for(int i = 0; i < N_READS; i++) {
vals[i] = in[i];
}
for(int i = 0; i < N_READS; i++) {
total_val = op(static_cast<U>(vals[i]), total_val);
}
in += lsize.x * N_READS;
}
// Separate case for the last set as we close the reduction size
size_t reduction_index = (lid.x + (size_t)lsize.x * r) * N_READS;
if(reduction_index < reduction_size) {
int max_reads = reduction_size - reduction_index;
T vals[N_READS];
for(int i = 0; i < N_READS; i++) {
int idx = min(i, max_reads - 1);
vals[i] = static_cast<U>(in[idx]);
}
for(int i = 0; i < N_READS; i++) {
T val = i < max_reads ? vals[i] : Op::init;
total_val = op(static_cast<U>(val), total_val);
}
}
U total_val = per_thread_row_reduce<T, U, Op, N_READS>(in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy);
total_val = op.simd_reduce(total_val);
@ -194,6 +279,53 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
}
}
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void row_reduce_general_no_atomics(
const device T *in [[buffer(0)]],
device U *out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant int* shape [[buffer(4)]],
const constant size_t* strides [[buffer(5)]],
const constant int& ndim [[buffer(6)]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]],
uint3 gsize [[threads_per_grid]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
Op op;
threadgroup U local_vals[simd_size];
U total_val = per_thread_row_reduce<T, U, Op, N_READS>(in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy);
// Reduction within simd group - simd_add isn't supported for int64 types
for (uint16_t i = simd_size/2; i > 0; i /= 2) {
total_val = op(total_val, simd_shuffle_down(total_val, i));
}
// Prepare next level
if (simd_lane_id == 0) {
local_vals[simd_group_id] = total_val;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Reduction within thread group
// Only needed if thread group has multiple simd groups
if(ceildiv(reduction_size, N_READS) > simd_size) {
total_val = lid.x < simd_per_group ? local_vals[lid.x] : op.init;
for (uint16_t i = simd_size/2; i > 0; i /= 2) {
total_val = op(total_val, simd_shuffle_down(total_val, i));
}
}
// Write row reduce output for threadgroup with 1st thread in thread group
if (lid.x == 0) {
out[(ceildiv(gsize.y, lsize.y) * tid.x) + tid.y] = total_val;
}
}
#define instantiate_row_reduce_general(name, itype, otype, op) \
template [[host_name("row_reduce_general_" #name)]] \
[[kernel]] void row_reduce_general<itype, otype, op>( \
@ -211,18 +343,33 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \
template [[host_name("row_reduce_general_no_atomics_" #name)]] \
[[kernel]] void row_reduce_general_no_atomics<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device otype *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant int* shape [[buffer(4)]], \
const constant size_t* strides [[buffer(5)]], \
const constant int& ndim [[buffer(6)]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 gsize [[threads_per_grid]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
///////////////////////////////////////////////////////////////////////////////
// Column reduce
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
inline void _contiguous_strided_reduce(
inline U _contiguous_strided_reduce(
const device T *in,
device mlx_atomic<U> *out,
threadgroup U *local_data,
uint in_idx,
uint out_idx,
uint reduction_size,
uint reduction_stride,
uint2 tid,
@ -230,33 +377,25 @@ inline void _contiguous_strided_reduce(
uint2 lsize) {
Op op;
T local_vals[N_READS];
U total_val = Op::init;
uint base_offset = (tid.y * lsize.y + lid.y) * N_READS;
for(uint r = 0; r < N_READS; r++) {
uint offset = base_offset + r;
offset = offset < reduction_size ? offset : reduction_size - 1;
local_vals[r] = in[in_idx + offset * reduction_stride];
}
U total_val = Op::init;
for(uint r = 0; r < N_READS && (base_offset + r) < reduction_size; r++) {
total_val = op(static_cast<U>(total_val), local_vals[r]);
uint offset = base_offset + r;
total_val = op(static_cast<U>(total_val), in[in_idx + offset * reduction_stride]);
}
local_data[lsize.y * lid.x + lid.y] = total_val;
threadgroup_barrier(mem_flags::mem_threadgroup);
U val = Op::init;
if(lid.y == 0) {
U val = op.init;
// Perform reduction across columns in thread group
for(uint i = 0; i < lsize.y; i++) {
val = op(val, local_data[lsize.y * lid.x + i]);
}
op.atomic_update(out, val, out_idx);
}
return val;
}
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
@ -281,18 +420,66 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
ndim
);
Op op;
if(out_idx < out_size) {
_contiguous_strided_reduce<T, U, Op, N_READS>(
U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
in,
out,
local_data,
in_idx,
out_idx,
reduction_size,
reduction_stride,
tid.xy,
lid.xy,
lsize.xy);
// Write out reduction results generated by threadgroups working on specific output element, contiguously.
if (lid.y == 0) {
op.atomic_update(out, val, out_idx);
}
}
}
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void col_reduce_general_no_atomics(
const device T *in [[buffer(0)]],
device U *out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
threadgroup U *local_data [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 gid [[thread_position_in_grid]],
uint3 lsize [[threads_per_threadgroup]],
uint3 gsize [[threads_per_grid]]) {
auto out_idx = tid.x * lsize.x + lid.x;
auto in_idx = elem_to_loc(
out_idx + tid.z * out_size,
shape,
strides,
ndim
);
if(out_idx < out_size) {
U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
in,
local_data,
in_idx,
reduction_size,
reduction_stride,
tid.xy,
lid.xy,
lsize.xy);
// Write out reduction results generated by threadgroups working on specific output element, contiguously.
if (lid.y == 0) {
uint tgsize_y = ceildiv(gsize.y, lsize.y);
uint tgsize_z = ceildiv(gsize.z, lsize.z);
out[tgsize_y * tgsize_z * gid.x + tgsize_y * tid.z + tid.y] = val;
}
}
}
@ -312,6 +499,23 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]]);
#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \
template [[host_name("col_reduce_general_no_atomics_" #name)]] \
[[kernel]] void col_reduce_general_no_atomics<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device otype *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
threadgroup otype *local_data [[threadgroup(0)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 gid [[thread_position_in_grid]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 gsize [[threads_per_grid]]);
///////////////////////////////////////////////////////////////////////////////
// Instantiations
@ -322,6 +526,15 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
instantiate_row_reduce_general(name, itype, otype, op) \
instantiate_col_reduce_general(name, itype, otype, op)
#define instantiate_reduce_no_atomics(name, itype, otype, op) \
instantiate_all_reduce_no_atomics(name, itype, otype, op) \
instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \
instantiate_col_reduce_general_no_atomics(name, itype, otype, op)
#define instantiate_same_reduce_no_atomics(name, tname, type, op) \
instantiate_init_reduce(name ##tname, type, op<type>) \
instantiate_reduce_no_atomics(name ##tname, type, type, op<type>)
#define instantiate_same_reduce(name, tname, type, op) \
instantiate_init_reduce(name ##tname, type, op<type>) \
instantiate_reduce(name ##tname, type, type, op<type>)
@ -353,6 +566,9 @@ instantiate_same_reduce(sum, int32, int32_t, Sum)
instantiate_same_reduce(sum, float16, half, Sum)
instantiate_same_reduce(sum, float32, float, Sum)
instantiate_same_reduce_no_atomics(sum, int64, int64_t, Sum)
instantiate_same_reduce_no_atomics(sum, uint64, uint64_t, Sum)
instantiate_same_reduce(prod, uint8, uint8_t, Prod)
instantiate_same_reduce(prod, uint16, uint16_t, Prod)
instantiate_same_reduce(prod, uint32, uint32_t, Prod)
@ -362,6 +578,9 @@ instantiate_same_reduce(prod, int32, int32_t, Prod)
instantiate_same_reduce(prod, float16, half, Prod)
instantiate_same_reduce(prod, float32, float, Prod)
instantiate_same_reduce_no_atomics(prod, int64, int64_t, Prod)
instantiate_same_reduce_no_atomics(prod, uint64, uint64_t, Prod)
instantiate_same_reduce(sum, bfloat16, bfloat16_t, Sum)
instantiate_same_reduce(prod, bfloat16, bfloat16_t, Prod)
@ -381,6 +600,9 @@ instantiate_same_reduce(min_, int32, int32_t, Min)
instantiate_same_reduce(min_, float16, half, Min)
instantiate_same_reduce(min_, float32, float, Min)
instantiate_same_reduce_no_atomics(min_, int64, int64_t, Min)
instantiate_same_reduce_no_atomics(min_, uint64, uint64_t, Min)
instantiate_same_reduce(max_, uint8, uint8_t, Max)
instantiate_same_reduce(max_, uint16, uint16_t, Max)
instantiate_same_reduce(max_, uint32, uint32_t, Max)
@ -390,5 +612,8 @@ instantiate_same_reduce(max_, int32, int32_t, Max)
instantiate_same_reduce(max_, float16, half, Max)
instantiate_same_reduce(max_, float32, float, Max)
instantiate_same_reduce_no_atomics(max_, int64, int64_t, Max)
instantiate_same_reduce_no_atomics(max_, uint64, uint64_t, Max)
instantiate_same_reduce(min_, bfloat16, bfloat16_t, Min)
instantiate_same_reduce(max_, bfloat16, bfloat16_t, Max)

View File

@ -256,3 +256,21 @@ inline bfloat16_t log1p(bfloat16_t x) {
return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
}
///////////////////////////////////////////////////////////////////////////////
// SIMD shuffle ops
///////////////////////////////////////////////////////////////////////////////
inline uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) {
return as_type<uint64_t>(
metal::simd_shuffle_down(as_type<uint2>(data), delta));
}
inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
return as_type<int64_t>(
metal::simd_shuffle_down(as_type<uint2>(data), delta));
}
inline bool simd_shuffle_down(bool data, uint16_t delta) {
return simd_shuffle_down(static_cast<uint32_t>(data), delta);
}

View File

@ -28,35 +28,40 @@ inline auto safe_divup(size_t n, size_t m) {
return safe_div(n, m) * m;
}
inline bool is_64b_int(Dtype dtype) {
return dtype == int64 || dtype == uint64;
}
// All Reduce
void all_reduce_dispatch(
const array& in,
array& out,
const std::string& op_name,
MTL::ComputeCommandEncoder* compute_encoder,
metal::Device& d) {
// Get kernel and encode buffers
size_t in_size = in.size();
auto kernel = d.get_kernel("all_reduce_" + op_name + type_to_name(in));
metal::Device& d,
const Stream& s) {
Dtype out_dtype = out.dtype();
bool is_out_64b_int = is_64b_int(out_dtype);
auto kernel = (is_out_64b_int)
? d.get_kernel("all_reduce_no_atomics_" + op_name + type_to_name(in))
: d.get_kernel("all_reduce_" + op_name + type_to_name(in));
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
// Set grid dimensions
// We make sure each thread has enough to do by making it read in
// at least n_reads inputs
int n_reads = REDUCE_N_READS;
size_t in_size = in.size();
// mod_in_size gives us the groups of n_reads needed to go over the entire
// input
uint mod_in_size = (in_size + n_reads - 1) / n_reads;
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
thread_group_size =
mod_in_size > thread_group_size ? thread_group_size : mod_in_size;
uint simd_size = kernel->threadExecutionWidth();
thread_group_size =
((thread_group_size + simd_size - 1) / simd_size) * simd_size;
// If the number of thread groups needed exceeds 1024, we reuse threads groups
uint n_thread_groups = safe_div(mod_in_size, thread_group_size);
@ -66,7 +71,52 @@ void all_reduce_dispatch(
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
// Encode buffers and dispatch
if (is_out_64b_int == false || n_thread_groups == 1) {
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
compute_encoder->dispatchThreads(grid_dims, group_dims);
} else {
// Allocate intermediate array to store partial reduction results
size_t intermediate_size = n_thread_groups;
array intermediate =
array({static_cast<int>(intermediate_size)}, out_dtype, nullptr, {});
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
std::vector<array> intermediates = {intermediate};
// First dispatch
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, intermediate, 1);
compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
compute_encoder->dispatchThreads(grid_dims, group_dims);
// Second pass to reduce intermediate reduction results written to DRAM
set_array_buffer(compute_encoder, intermediate, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&intermediate_size, sizeof(size_t), 2);
mod_in_size = (intermediate_size + n_reads - 1) / n_reads;
thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
thread_group_size =
mod_in_size > thread_group_size ? thread_group_size : mod_in_size;
thread_group_size =
((thread_group_size + simd_size - 1) / simd_size) * simd_size;
// If the number of thread groups needed exceeds 1024, we reuse threads
// groups
nthreads = thread_group_size;
group_dims = MTL::Size(thread_group_size, 1, 1);
grid_dims = MTL::Size(nthreads, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[intermediates](MTL::CommandBuffer*) mutable {
intermediates.clear();
});
}
}
void row_reduce_general_dispatch(
@ -76,22 +126,31 @@ void row_reduce_general_dispatch(
const ReductionPlan& plan,
const std::vector<int>& axes,
MTL::ComputeCommandEncoder* compute_encoder,
metal::Device& d) {
auto kernel =
d.get_kernel("row_reduce_general_" + op_name + type_to_name(in));
metal::Device& d,
const Stream& s) {
Dtype out_dtype = out.dtype();
bool is_out_64b_int = is_64b_int(out_dtype);
auto kernel = (is_out_64b_int)
? d.get_kernel(
"row_reduce_general_no_atomics_" + op_name + type_to_name(in))
: d.get_kernel("row_reduce_general_" + op_name + type_to_name(in));
compute_encoder->setComputePipelineState(kernel);
// Prepare the arguments for the kernel
int n_reads = REDUCE_N_READS;
size_t reduction_size = plan.shape.back();
size_t out_size = out.size();
auto shape = plan.shape;
auto strides = plan.strides;
shape.pop_back();
strides.pop_back();
size_t non_row_reductions = 1;
for (auto s : shape) {
non_row_reductions *= static_cast<size_t>(s);
}
size_t out_size = out.size();
auto [rem_shape, rem_strides] = shapes_without_reduction_axes(in, axes);
for (auto s : rem_shape) {
shape.push_back(s);
@ -101,16 +160,6 @@ void row_reduce_general_dispatch(
}
int ndim = shape.size();
// Set the arguments for the kernel
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4);
compute_encoder->setBytes(strides.data(), strides.size() * sizeof(size_t), 5);
compute_encoder->setBytes(&ndim, sizeof(int), 6);
// Each thread group is responsible for 1 output
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
thread_group_size =
@ -127,7 +176,88 @@ void row_reduce_general_dispatch(
MTL::Size grid_dims = MTL::Size(n_threads, non_row_reductions, 1);
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
if (is_out_64b_int == false || non_row_reductions == 1) {
// Set the arguments for the kernel
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4);
compute_encoder->setBytes(
strides.data(), strides.size() * sizeof(size_t), 5);
compute_encoder->setBytes(&ndim, sizeof(int), 6);
compute_encoder->dispatchThreads(grid_dims, group_dims);
} else {
// Allocate intermediate array to store partial reduction results
array intermediate = array(
{static_cast<int>(out.size()), static_cast<int>(non_row_reductions)},
out_dtype,
nullptr,
{});
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
std::vector<array> intermediates = {intermediate};
// Set the arguments for the kernel
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, intermediate, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4);
compute_encoder->setBytes(
strides.data(), strides.size() * sizeof(size_t), 5);
compute_encoder->setBytes(&ndim, sizeof(int), 6);
compute_encoder->dispatchThreads(grid_dims, group_dims);
// Set up second dispatch
reduction_size = non_row_reductions;
out_size = 1;
// Shape of axes that aren't participating in reduction remains unchanged.
std::vector<int> new_shape = rem_shape;
// Update their strides since they'll be different post partial reduction in
// first compute dispatch.
std::vector<size_t> new_strides = rem_strides;
new_strides.back() = reduction_size;
for (int i = new_shape.size() - 2; i >= 0; i--) {
new_strides[i] = new_shape[i + 1] * new_strides[i + 1];
}
ndim = new_shape.size();
// Set the arguments for the kernel
set_array_buffer(compute_encoder, intermediate, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
compute_encoder->setBytes(
new_shape.data(), new_shape.size() * sizeof(int), 4);
compute_encoder->setBytes(
new_strides.data(), new_strides.size() * sizeof(size_t), 5);
compute_encoder->setBytes(&ndim, sizeof(int), 6);
// Each thread group is responsible for 1 output
thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
thread_group_size =
std::min((reduction_size + n_reads - 1) / n_reads, thread_group_size);
// Align thread group size with simd_size
thread_group_size =
(thread_group_size + simd_size - 1) / simd_size * simd_size;
assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup());
// Launch enough thread groups for each output
n_threads = thread_group_size;
grid_dims = MTL::Size(n_threads, out.size(), 1);
group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[intermediates](MTL::CommandBuffer*) mutable {
intermediates.clear();
});
}
}
void strided_reduce_general_dispatch(
@ -137,9 +267,16 @@ void strided_reduce_general_dispatch(
const ReductionPlan& plan,
const std::vector<int>& axes,
MTL::ComputeCommandEncoder* compute_encoder,
metal::Device& d) {
auto kernel =
d.get_kernel("col_reduce_general_" + op_name + type_to_name(in));
metal::Device& d,
const Stream& s) {
Dtype out_dtype = out.dtype();
bool is_out_64b_int = is_64b_int(out_dtype);
auto kernel = (is_out_64b_int)
? d.get_kernel(
"col_reduce_general_no_atomics_" + op_name + type_to_name(in))
: d.get_kernel("col_reduce_general_" + op_name + type_to_name(in));
compute_encoder->setComputePipelineState(kernel);
// Prepare the arguments for the kernel
size_t reduction_size = plan.shape.back();
@ -162,19 +299,7 @@ void strided_reduce_general_dispatch(
}
int ndim = shape.size();
// Set the arguments for the kernel
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
compute_encoder->setBytes(&out_size, sizeof(size_t), 4);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5);
compute_encoder->setBytes(strides.data(), strides.size() * sizeof(size_t), 6);
compute_encoder->setBytes(&ndim, sizeof(int), 7);
// Select block dimensions
// Each thread reads 16 inputs to give it more work
uint n_inputs_per_thread = REDUCE_N_READS;
uint n_threads_per_output =
@ -183,14 +308,22 @@ void strided_reduce_general_dispatch(
// We spread outputs over the x dimension and inputs over the y dimension
// Threads with the same lid.x in a given threadgroup work on the same
// output and each thread in the y dimension accumulates for that output
// Threads with same lid.x, i.e. each column of threads work on same output
uint threadgroup_dim_x = std::min(out_size, 128ul);
// Number of threads along y, is dependent on number of reductions needed.
uint threadgroup_dim_y =
kernel->maxTotalThreadsPerThreadgroup() / threadgroup_dim_x;
threadgroup_dim_y = std::min(n_threads_per_output, threadgroup_dim_y);
// Derive number of thread groups along x, based on how many threads we need
// along x
uint n_threadgroups_x =
(out_size + threadgroup_dim_x - 1) / threadgroup_dim_x;
// Derive number of thread groups along y based on how many threads we need
// along y
uint n_threadgroups_y =
(n_threads_per_output + threadgroup_dim_y - 1) / threadgroup_dim_y;
@ -199,6 +332,18 @@ void strided_reduce_general_dispatch(
MTL::Size(n_threadgroups_x, n_threadgroups_y, non_col_reductions);
MTL::Size group_dims = MTL::Size(threadgroup_dim_x, threadgroup_dim_y, 1);
if (is_out_64b_int == false) {
// Set the arguments for the kernel
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
compute_encoder->setBytes(&out_size, sizeof(size_t), 4);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5);
compute_encoder->setBytes(
strides.data(), strides.size() * sizeof(size_t), 6);
compute_encoder->setBytes(&ndim, sizeof(int), 7);
// We set shared memory to be exploited here for reductions within a
// threadgroup - each thread must be able to update its accumulated output
// Note: Each threadgroup should have 32kB of data in threadgroup memory
@ -209,8 +354,100 @@ void strided_reduce_general_dispatch(
compute_encoder->setThreadgroupMemoryLength(
safe_divup(threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 16),
0);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
} else {
// Allocate intermediate array to store reduction results from all thread
// groups
array intermediate = array(
{static_cast<int>(out.size()),
static_cast<int>(n_threadgroups_y * non_col_reductions)},
out_dtype,
nullptr,
{});
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
std::vector<array> intermediates = {intermediate};
// Set the arguments for the kernel
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, intermediate, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
compute_encoder->setBytes(&out_size, sizeof(size_t), 4);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5);
compute_encoder->setBytes(
strides.data(), strides.size() * sizeof(size_t), 6);
compute_encoder->setBytes(&ndim, sizeof(int), 7);
// We set shared memory to be exploited here for reductions within a
// threadgroup - each thread must be able to update its accumulated output
// Note: Each threadgroup should have 32kB of data in threadgroup memory
// and threadgroup_dim_x * threadgroup_dim_y <= 1024 by design
// This should be fine for floats, but we might need to revisit
// if we ever come to doubles. In that case, we should also cut
// down the number of threads we launch in a threadgroup
compute_encoder->setThreadgroupMemoryLength(
safe_divup(threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 16),
0);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
// Perform second pass of reductions
// Reduce results of threadgroups along y, z from first pass, that
// collectively work on each output element.
reduction_size = n_threadgroups_y * non_col_reductions;
out_size = 1;
// Shape of axes that aren't participating in reduction remains unchanged.
std::vector<int> new_shape = rem_shape;
// Update their strides since they'll be different after a partial reduction
// post first compute dispatch.
std::vector<size_t> new_strides = rem_strides;
new_strides.back() = reduction_size;
for (int i = new_shape.size() - 2; i >= 0; i--) {
new_strides[i] = new_shape[i + 1] * new_strides[i + 1];
}
ndim = new_shape.size();
auto row_reduce_kernel = d.get_kernel(
"row_reduce_general_no_atomics_" + op_name +
type_to_name(intermediate));
compute_encoder->setComputePipelineState(row_reduce_kernel);
set_array_buffer(compute_encoder, intermediate, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
compute_encoder->setBytes(
new_shape.data(), new_shape.size() * sizeof(int), 4);
compute_encoder->setBytes(
new_strides.data(), new_strides.size() * sizeof(size_t), 5);
compute_encoder->setBytes(&ndim, sizeof(int), 6);
// Each thread group is responsible for 1 output
size_t n_reads = REDUCE_N_READS;
size_t thread_group_size =
row_reduce_kernel->maxTotalThreadsPerThreadgroup();
thread_group_size =
std::min((reduction_size + n_reads - 1) / n_reads, thread_group_size);
// Align thread group size with simd_size
uint simd_size = row_reduce_kernel->threadExecutionWidth();
thread_group_size =
(thread_group_size + simd_size - 1) / simd_size * simd_size;
assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup());
// Launch enough thread groups for each output
uint n_threads = thread_group_size;
grid_dims = MTL::Size(n_threads, out.size(), 1);
group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[intermediates](MTL::CommandBuffer*) mutable {
intermediates.clear();
});
}
}
} // namespace
@ -223,14 +460,6 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
array in = inputs[0];
// TODO: Allow specific row and column reductions with types disabled
// due to atomics ?
if (size_of(in.dtype()) == 8) {
std::ostringstream msg;
msg << "[Reduce::eval_gpu] Does not support " << in.dtype();
throw std::runtime_error(msg.str());
}
// Make sure no identity reductions trickle down here
assert(!axes_.empty());
@ -297,7 +526,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
// Reducing over everything and the data is all there no broadcasting or
// slicing etc.
if (plan.type == ContiguousAllReduce) {
all_reduce_dispatch(in, out, op_name, compute_encoder, d);
all_reduce_dispatch(in, out, op_name, compute_encoder, d, s);
}
// At least the last dimension is row contiguous and we are reducing over
@ -305,7 +534,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
else if (
plan.type == ContiguousReduce || plan.type == GeneralContiguousReduce) {
row_reduce_general_dispatch(
in, out, op_name, plan, axes_, compute_encoder, d);
in, out, op_name, plan, axes_, compute_encoder, d, s);
}
// At least the last two dimensions are contiguous and we are doing a
@ -314,7 +543,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
plan.type == ContiguousStridedReduce ||
plan.type == GeneralStridedReduce) {
strided_reduce_general_dispatch(
in, out, op_name, plan, axes_, compute_encoder, d);
in, out, op_name, plan, axes_, compute_encoder, d, s);
}
if (!copies.empty()) {

View File

@ -55,6 +55,8 @@ class TestReduce(mlx_tests.MLXTestCase):
"uint8",
"uint16",
"uint32",
"int64",
"uint64",
]
float_dtypes = ["float32"]