Add GPU support for uint64/int64 reductions (#569)

Vijay Krish 2024-01-31 11:18:04 -08:00 committed by GitHub
parent bad67fec37
commit fcc5ac1c64
5 changed files with 654 additions and 192 deletions

View File

@@ -63,18 +63,6 @@ struct ArgMax {
}
};
bool simd_shuffle_down(bool data, uint16_t delta) {
return simd_shuffle_down(static_cast<uint32_t>(data), delta);
}
uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) {
return as_type<uint64_t>(simd_shuffle_down(as_type<uint2>(data), delta));
}
int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
return as_type<int64_t>(simd_shuffle_down(as_type<uint2>(data), delta));
}
template <typename U>
IndexValPair<U> simd_shuffle_down(IndexValPair<U> data, uint16_t delta) {
return IndexValPair<U>(

View File

@@ -24,11 +24,59 @@ template <typename T, typename Op>
device otype *out [[buffer(1)]], \
uint tid [[thread_position_in_grid]]);
///////////////////////////////////////////////////////////////////////////////
// All reduce
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
inline U per_thread_all_reduce(
const device T *in,
const device size_t& in_size,
uint gid,
uint grid_size) {
Op op;
U total_val = Op::init;
if (gid * N_READS < in_size) {
in += gid * N_READS;
int r = 0;
for(; r < (int)ceildiv(in_size, grid_size * N_READS) - 1; r++) {
U vals[N_READS] = {op.init};
for(int i = 0; i < N_READS; i++) {
vals[i] = static_cast<U>(in[i]);
}
for(int i = 0; i < N_READS; i++) {
total_val = op(vals[i], total_val);
}
in += grid_size * N_READS;
}
// Separate case for the last set as we close the reduction size
size_t curr_idx = (gid + r * (size_t)grid_size) * N_READS;
if (curr_idx < in_size) {
int max_reads = in_size - curr_idx;
T vals[N_READS];
for(int i = 0, idx = 0; i < N_READS; i++, idx++) {
idx = idx < max_reads ? idx : max_reads - 1;
vals[i] = in[idx];
}
for(int i = 0; i < N_READS; i++) {
U val = i < max_reads ? vals[i] : Op::init;
total_val = op(static_cast<U>(val), total_val);
}
}
}
return total_val;
}
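For intuition about the access pattern above, here is a small host-side C++ trace of which input indices one thread accumulates. All concrete values (N_READS = 4, grid_size = 8, in_size = 70, gid = 2) are illustrative assumptions, not taken from the diff; the device code clamps tail reads instead of skipping them, but the set of accumulated elements is the same.

#include <cstddef>
#include <cstdio>

int main() {
  const int N_READS = 4, grid_size = 8;   // illustrative values
  const size_t in_size = 70;
  const unsigned gid = 2;                  // trace one example thread
  // ceildiv(in_size, grid_size * N_READS) rounds of strided reads
  size_t rounds = (in_size + (size_t)grid_size * N_READS - 1) /
      ((size_t)grid_size * N_READS);
  for (size_t r = 0; r < rounds; r++) {
    size_t base = ((size_t)gid + r * grid_size) * N_READS;
    for (int i = 0; i < N_READS && base + i < in_size; i++) {
      printf("%zu ", base + i);            // prints 8 9 10 11 40 41 42 43
    }
  }
  printf("\n");                            // the tail (64..69) falls to threads 0 and 1
  return 0;
}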
// NB: This kernel assumes threads_per_threadgroup is at most
// 1024. This way with a simd_size of 32, we are guaranteed to
// complete the reduction in two steps of simd-level reductions.
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void all_reduce(
const device T *in [[buffer(0)]],
@@ -40,53 +88,18 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
// NB: this kernel assumes threads_per_threadgroup is at most
// 1024. This way with a simd_size of 32, we are guaranteed to
// complete the reduction in two steps of simd-level reductions.
Op op;
threadgroup U local_vals[simd_size];
U total_val = Op::init;
in += gid * N_READS;
U total_val = per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);
int r = 0;
for(; r < (int)ceildiv(in_size, grid_size * N_READS) - 1; r++) {
U vals[N_READS] = {op.init};
for(int i = 0; i < N_READS; i++) {
vals[i] = static_cast<U>(in[i]);
}
for(int i = 0; i < N_READS; i++) {
total_val = op(vals[i], total_val);
}
in += grid_size * N_READS;
}
// Separate case for the last set as we close the reduction size
size_t curr_idx = (gid + r * (size_t)grid_size) * N_READS;
if (curr_idx < in_size) {
int max_reads = in_size - curr_idx;
T vals[N_READS];
for(int i = 0, idx = 0; i < N_READS; i++, idx++) {
idx = idx < max_reads ? idx : max_reads - 1;
vals[i] = in[idx];
}
for(int i = 0; i < N_READS; i++) {
U val = i < max_reads ? vals[i] : Op::init;
total_val = op(static_cast<U>(val), total_val);
}
}
// Reduction within simd group
total_val = op.simd_reduce(total_val);
if (simd_lane_id == 0) {
local_vals[simd_group_id] = total_val;
}
// Reduction within thread group
threadgroup_barrier(mem_flags::mem_threadgroup);
total_val = lid < simd_per_group ? local_vals[lid] : op.init;
@@ -98,6 +111,46 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
}
}
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void all_reduce_no_atomics(
const device T *in [[buffer(0)]],
device U *out [[buffer(1)]],
const device size_t& in_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint grid_size [[threads_per_grid]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint thread_group_id [[threadgroup_position_in_grid]]) {
Op op;
threadgroup U local_vals[simd_size];
U total_val = per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);
// Reduction within simd group (simd_add isn't supported for uint64/int64 types)
for (uint16_t lane_offset = simd_size/2; lane_offset > 0; lane_offset /= 2) {
total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
}
// Write simd group reduction results to local memory
if (simd_lane_id == 0) {
local_vals[simd_group_id] = total_val;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Reduction of simdgroup reduction results within threadgroup.
total_val = lid < simd_per_group ? local_vals[lid] : op.init;
for (uint16_t lane_offset = simd_size/2; lane_offset > 0; lane_offset /= 2) {
total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
}
// Reduction across threadgroups
if (lid == 0) {
out[thread_group_id] = total_val;
}
}
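The shift-down loops above implement a log2(simd_size) tree reduction without relying on simd_sum/simd_max, which are unavailable for 64-bit integers. A minimal CPU-side C++ simulation of that pattern is sketched below, assuming a 32-lane simdgroup and addition as the op; the lane values are illustrative.

#include <cstdint>
#include <cstdio>

int main() {
  const int simd_size = 32;
  uint64_t lanes[simd_size];
  for (int i = 0; i < simd_size; i++) lanes[i] = (uint64_t)(i + 1);  // 1..32, sum = 528
  // Each step folds the value lane_offset lanes above onto the current lane,
  // mirroring total_val = op(total_val, simd_shuffle_down(total_val, lane_offset)).
  for (int lane_offset = simd_size / 2; lane_offset > 0; lane_offset /= 2) {
    for (int lane = 0; lane + lane_offset < simd_size; lane++) {
      lanes[lane] += lanes[lane + lane_offset];
    }
  }
  printf("%llu\n", (unsigned long long)lanes[0]);  // 528: lane 0 holds the simdgroup total
  return 0;
}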
#define instantiate_all_reduce(name, itype, otype, op) \
template [[host_name("all_reduce_" #name)]] \
[[kernel]] void all_reduce<itype, otype, op>( \
@@ -111,11 +164,80 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \
template [[host_name("all_reduce_no_atomics_" #name)]] \
[[kernel]] void all_reduce_no_atomics<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device otype *out [[buffer(1)]], \
const device size_t& in_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint grid_size [[threads_per_grid]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint thread_group_id [[threadgroup_position_in_grid]]);
///////////////////////////////////////////////////////////////////////////////
// Row atomics
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
inline U per_thread_row_reduce(
const device T *in,
const constant size_t& reduction_size,
const constant size_t& out_size,
const constant int* shape,
const constant size_t* strides,
const constant int& ndim,
uint lsize_x,
uint lid_x,
uint2 tid) {
Op op;
// Each threadgroup handles 1 reduction
// TODO: Specializing elem_to_loc would be slightly faster
int idx = tid.y * out_size + tid.x;
int extra_offset = elem_to_loc(idx, shape, strides, ndim);
in += extra_offset + lid_x * N_READS;
// The reduction is accumulated here
U total_val = Op::init;
// Loop over the reduction size within thread group
int r = 0;
for (; r < (int)ceildiv(reduction_size, N_READS*lsize_x) - 1; r++) {
T vals[N_READS];
for(int i = 0; i < N_READS; i++) {
vals[i] = in[i];
}
for(int i = 0; i < N_READS; i++) {
total_val = op(static_cast<U>(vals[i]), total_val);
}
in += lsize_x * N_READS;
}
// Separate case for the last set as we close the reduction size
size_t reduction_index = (lid_x + (size_t)lsize_x * r) * N_READS;
if(reduction_index < reduction_size) {
int max_reads = reduction_size - reduction_index;
T vals[N_READS];
for(int i = 0; i < N_READS; i++) {
int idx = min(i, max_reads - 1);
vals[i] = static_cast<U>(in[idx]);
}
for(int i = 0; i < N_READS; i++) {
T val = i < max_reads ? vals[i] : Op::init;
total_val = op(static_cast<U>(val), total_val);
}
}
return total_val;
}
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void row_reduce_general(
const device T *in [[buffer(0)]],
@@ -133,46 +255,9 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
Op op;
// Each threadgroup handles 1 reduction
// TODO: Specializing elem_to_loc would be slightly faster
int idx = tid.y * out_size + tid.x;
int extra_offset = elem_to_loc(idx, shape, strides, ndim);
in += extra_offset + lid.x * N_READS;
// The reduction is accumulated here
U total_val = Op::init;
threadgroup U local_vals[simd_size];
// Loop over the reduction size within thread group
U total_val = per_thread_row_reduce<T, U, Op, N_READS>(in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy);
int r = 0;
for (; r < (int)ceildiv(reduction_size, N_READS*lsize.x) - 1; r++) {
T vals[N_READS];
for(int i = 0; i < N_READS; i++) {
vals[i] = in[i];
}
for(int i = 0; i < N_READS; i++) {
total_val = op(static_cast<U>(vals[i]), total_val);
}
in += lsize.x * N_READS;
}
// Separate case for the last set as we close the reduction size
size_t reduction_index = (lid.x + (size_t)lsize.x * r) * N_READS;
if(reduction_index < reduction_size) {
int max_reads = reduction_size - reduction_index;
T vals[N_READS];
for(int i = 0; i < N_READS; i++) {
int idx = min(i, max_reads - 1);
vals[i] = static_cast<U>(in[idx]);
}
for(int i = 0; i < N_READS; i++) {
T val = i < max_reads ? vals[i] : Op::init;
total_val = op(static_cast<U>(val), total_val);
}
}
total_val = op.simd_reduce(total_val);
@@ -194,6 +279,53 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
}
}
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void row_reduce_general_no_atomics(
const device T *in [[buffer(0)]],
device U *out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant int* shape [[buffer(4)]],
const constant size_t* strides [[buffer(5)]],
const constant int& ndim [[buffer(6)]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]],
uint3 gsize [[threads_per_grid]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
Op op;
threadgroup U local_vals[simd_size];
U total_val = per_thread_row_reduce<T, U, Op, N_READS>(in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy);
// Reduction within simd group - simd_add isn't supported for int64 types
for (uint16_t i = simd_size/2; i > 0; i /= 2) {
total_val = op(total_val, simd_shuffle_down(total_val, i));
}
// Prepare next level
if (simd_lane_id == 0) {
local_vals[simd_group_id] = total_val;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Reduction within thread group
// Only needed if thread group has multiple simd groups
if(ceildiv(reduction_size, N_READS) > simd_size) {
total_val = lid.x < simd_per_group ? local_vals[lid.x] : op.init;
for (uint16_t i = simd_size/2; i > 0; i /= 2) {
total_val = op(total_val, simd_shuffle_down(total_val, i));
}
}
// The first thread in each threadgroup writes out that group's row-reduce result
if (lid.x == 0) {
out[(ceildiv(gsize.y, lsize.y) * tid.x) + tid.y] = total_val;
}
}
#define instantiate_row_reduce_general(name, itype, otype, op) \
template [[host_name("row_reduce_general_" #name)]] \
[[kernel]] void row_reduce_general<itype, otype, op>( \
@@ -211,52 +343,59 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \
template [[host_name("row_reduce_general_no_atomics_" #name)]] \
[[kernel]] void row_reduce_general_no_atomics<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device otype *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant int* shape [[buffer(4)]], \
const constant size_t* strides [[buffer(5)]], \
const constant int& ndim [[buffer(6)]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 gsize [[threads_per_grid]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
///////////////////////////////////////////////////////////////////////////////
// Column reduce
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
inline void _contiguous_strided_reduce(
const device T *in,
device mlx_atomic<U> *out,
threadgroup U *local_data,
uint in_idx,
uint out_idx,
uint reduction_size,
uint reduction_stride,
uint2 tid,
uint2 lid,
uint2 lsize) {
Op op;
T local_vals[N_READS];
uint base_offset = (tid.y * lsize.y + lid.y) * N_READS;
for(uint r = 0; r < N_READS; r++) {
uint offset = base_offset + r;
offset = offset < reduction_size ? offset : reduction_size - 1;
local_vals[r] = in[in_idx + offset * reduction_stride];
}
U total_val = Op::init;
for(uint r = 0; r < N_READS && (base_offset + r) < reduction_size; r++) {
total_val = op(static_cast<U>(total_val), local_vals[r]);
}
local_data[lsize.y * lid.x + lid.y] = total_val;
threadgroup_barrier(mem_flags::mem_threadgroup);
if(lid.y == 0) {
U val = op.init;
for(uint i = 0; i < lsize.y; i++) {
val = op(val, local_data[lsize.y * lid.x + i]);
}
op.atomic_update(out, val, out_idx);
}
}
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
inline U _contiguous_strided_reduce(
const device T *in,
threadgroup U *local_data,
uint in_idx,
uint reduction_size,
uint reduction_stride,
uint2 tid,
uint2 lid,
uint2 lsize) {
Op op;
U total_val = Op::init;
uint base_offset = (tid.y * lsize.y + lid.y) * N_READS;
for(uint r = 0; r < N_READS && (base_offset + r) < reduction_size; r++) {
uint offset = base_offset + r;
total_val = op(static_cast<U>(total_val), in[in_idx + offset * reduction_stride]);
}
local_data[lsize.y * lid.x + lid.y] = total_val;
threadgroup_barrier(mem_flags::mem_threadgroup);
U val = Op::init;
if(lid.y == 0) {
// Perform reduction across columns in thread group
for(uint i = 0; i < lsize.y; i++) {
val = op(val, local_data[lsize.y * lid.x + i]);
}
}
return val;
}
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
@@ -265,13 +404,13 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
device mlx_atomic<U> *out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
threadgroup U *local_data [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]]) {
auto out_idx = tid.x * lsize.x + lid.x;
auto in_idx = elem_to_loc(
@@ -281,18 +420,66 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
ndim
);
Op op;
if(out_idx < out_size) {
_contiguous_strided_reduce<T, U, Op, N_READS>(
in,
out,
local_data,
in_idx,
out_idx,
reduction_size,
reduction_stride,
tid.xy,
lid.xy,
lsize.xy);
U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
in,
local_data,
in_idx,
reduction_size,
reduction_stride,
tid.xy,
lid.xy,
lsize.xy);
// Write out reduction results generated by threadgroups working on specific output element, contiguously.
if (lid.y == 0) {
op.atomic_update(out, val, out_idx);
}
}
}
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void col_reduce_general_no_atomics(
const device T *in [[buffer(0)]],
device U *out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
threadgroup U *local_data [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 gid [[thread_position_in_grid]],
uint3 lsize [[threads_per_threadgroup]],
uint3 gsize [[threads_per_grid]]) {
auto out_idx = tid.x * lsize.x + lid.x;
auto in_idx = elem_to_loc(
out_idx + tid.z * out_size,
shape,
strides,
ndim
);
if(out_idx < out_size) {
U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
in,
local_data,
in_idx,
reduction_size,
reduction_stride,
tid.xy,
lid.xy,
lsize.xy);
// Write out reduction results generated by threadgroups working on specific output element, contiguously.
if (lid.y == 0) {
uint tgsize_y = ceildiv(gsize.y, lsize.y);
uint tgsize_z = ceildiv(gsize.z, lsize.z);
out[tgsize_y * tgsize_z * gid.x + tgsize_y * tid.z + tid.y] = val;
}
}
}
@@ -312,6 +499,23 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]]);
#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \
template [[host_name("col_reduce_general_no_atomics_" #name)]] \
[[kernel]] void col_reduce_general_no_atomics<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device otype *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
threadgroup otype *local_data [[threadgroup(0)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 gid [[thread_position_in_grid]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 gsize [[threads_per_grid]]);
///////////////////////////////////////////////////////////////////////////////
// Instantiations
@@ -322,6 +526,15 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
instantiate_row_reduce_general(name, itype, otype, op) \
instantiate_col_reduce_general(name, itype, otype, op)
#define instantiate_reduce_no_atomics(name, itype, otype, op) \
instantiate_all_reduce_no_atomics(name, itype, otype, op) \
instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \
instantiate_col_reduce_general_no_atomics(name, itype, otype, op)
#define instantiate_same_reduce_no_atomics(name, tname, type, op) \
instantiate_init_reduce(name ##tname, type, op<type>) \
instantiate_reduce_no_atomics(name ##tname, type, type, op<type>)
#define instantiate_same_reduce(name, tname, type, op) \
instantiate_init_reduce(name ##tname, type, op<type>) \
instantiate_reduce(name ##tname, type, type, op<type>)
@@ -353,6 +566,9 @@ instantiate_same_reduce(sum, int32, int32_t, Sum)
instantiate_same_reduce(sum, float16, half, Sum)
instantiate_same_reduce(sum, float32, float, Sum)
instantiate_same_reduce_no_atomics(sum, int64, int64_t, Sum)
instantiate_same_reduce_no_atomics(sum, uint64, uint64_t, Sum)
instantiate_same_reduce(prod, uint8, uint8_t, Prod)
instantiate_same_reduce(prod, uint16, uint16_t, Prod)
instantiate_same_reduce(prod, uint32, uint32_t, Prod)
@@ -362,6 +578,9 @@ instantiate_same_reduce(prod, int32, int32_t, Prod)
instantiate_same_reduce(prod, float16, half, Prod)
instantiate_same_reduce(prod, float32, float, Prod)
instantiate_same_reduce_no_atomics(prod, int64, int64_t, Prod)
instantiate_same_reduce_no_atomics(prod, uint64, uint64_t, Prod)
instantiate_same_reduce(sum, bfloat16, bfloat16_t, Sum)
instantiate_same_reduce(prod, bfloat16, bfloat16_t, Prod)
@@ -381,6 +600,9 @@ instantiate_same_reduce(min_, int32, int32_t, Min)
instantiate_same_reduce(min_, float16, half, Min)
instantiate_same_reduce(min_, float32, float, Min)
instantiate_same_reduce_no_atomics(min_, int64, int64_t, Min)
instantiate_same_reduce_no_atomics(min_, uint64, uint64_t, Min)
instantiate_same_reduce(max_, uint8, uint8_t, Max)
instantiate_same_reduce(max_, uint16, uint16_t, Max)
instantiate_same_reduce(max_, uint32, uint32_t, Max)
@@ -390,5 +612,8 @@ instantiate_same_reduce(max_, int32, int32_t, Max)
instantiate_same_reduce(max_, float16, half, Max)
instantiate_same_reduce(max_, float32, float, Max)
instantiate_same_reduce_no_atomics(max_, int64, int64_t, Max)
instantiate_same_reduce_no_atomics(max_, uint64, uint64_t, Max)
instantiate_same_reduce(min_, bfloat16, bfloat16_t, Min)
instantiate_same_reduce(max_, bfloat16, bfloat16_t, Max)

View File

@@ -256,3 +256,21 @@ inline bfloat16_t log1p(bfloat16_t x) {
return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
}
///////////////////////////////////////////////////////////////////////////////
// SIMD shuffle ops
///////////////////////////////////////////////////////////////////////////////
inline uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) {
return as_type<uint64_t>(
metal::simd_shuffle_down(as_type<uint2>(data), delta));
}
inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
return as_type<int64_t>(
metal::simd_shuffle_down(as_type<uint2>(data), delta));
}
inline bool simd_shuffle_down(bool data, uint16_t delta) {
return simd_shuffle_down(static_cast<uint32_t>(data), delta);
}
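The two overloads above move a 64-bit value through a 32-bit shuffle by reinterpreting it as a pair of 32-bit lanes. A host-side C++ sketch of that reinterpretation round trip is shown below; uint2_t and memcpy are stand-ins for Metal's uint2 and as_type<>, and no actual shuffle is performed.

#include <cstdint>
#include <cstdio>
#include <cstring>

struct uint2_t { uint32_t x, y; };  // stand-in for Metal's uint2

int main() {
  uint64_t data = 0x0123456789abcdefULL;
  uint2_t halves;
  std::memcpy(&halves, &data, sizeof(halves));  // as_type<uint2>(data)
  // ...both 32-bit halves would travel through simd_shuffle_down here...
  uint64_t back;
  std::memcpy(&back, &halves, sizeof(back));    // as_type<uint64_t>(halves)
  printf("%d\n", back == data);                 // prints 1: lossless round trip
  return 0;
}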

View File

@@ -28,35 +28,40 @@ inline auto safe_divup(size_t n, size_t m) {
return safe_div(n, m) * m;
}
inline bool is_64b_int(Dtype dtype) {
return dtype == int64 || dtype == uint64;
}
// All Reduce
void all_reduce_dispatch(
const array& in,
array& out,
const std::string& op_name,
MTL::ComputeCommandEncoder* compute_encoder,
metal::Device& d) {
// Get kernel and encode buffers
size_t in_size = in.size();
auto kernel = d.get_kernel("all_reduce_" + op_name + type_to_name(in));
metal::Device& d,
const Stream& s) {
Dtype out_dtype = out.dtype();
bool is_out_64b_int = is_64b_int(out_dtype);
auto kernel = (is_out_64b_int)
? d.get_kernel("all_reduce_no_atomics_" + op_name + type_to_name(in))
: d.get_kernel("all_reduce_" + op_name + type_to_name(in));
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
// Set grid dimensions
// We make sure each thread has enough to do by making it read in
// at least n_reads inputs
int n_reads = REDUCE_N_READS;
size_t in_size = in.size();
// mod_in_size gives us the groups of n_reads needed to go over the entire
// input
uint mod_in_size = (in_size + n_reads - 1) / n_reads;
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
thread_group_size =
mod_in_size > thread_group_size ? thread_group_size : mod_in_size;
uint simd_size = kernel->threadExecutionWidth();
thread_group_size =
((thread_group_size + simd_size - 1) / simd_size) * simd_size;
// If the number of thread groups needed exceeds 1024, we reuse thread groups
uint n_thread_groups = safe_div(mod_in_size, thread_group_size);
@@ -66,7 +71,52 @@ void all_reduce_dispatch(
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
// Encode buffers and dispatch
if (is_out_64b_int == false || n_thread_groups == 1) {
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
compute_encoder->dispatchThreads(grid_dims, group_dims);
} else {
// Allocate intermediate array to store partial reduction results
size_t intermediate_size = n_thread_groups;
array intermediate =
array({static_cast<int>(intermediate_size)}, out_dtype, nullptr, {});
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
std::vector<array> intermediates = {intermediate};
// First dispatch
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, intermediate, 1);
compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
compute_encoder->dispatchThreads(grid_dims, group_dims);
// Second pass to reduce intermediate reduction results written to DRAM
set_array_buffer(compute_encoder, intermediate, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&intermediate_size, sizeof(size_t), 2);
mod_in_size = (intermediate_size + n_reads - 1) / n_reads;
thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
thread_group_size =
mod_in_size > thread_group_size ? thread_group_size : mod_in_size;
thread_group_size =
((thread_group_size + simd_size - 1) / simd_size) * simd_size;
// If the number of thread groups needed exceeds 1024, we reuse thread
// groups
nthreads = thread_group_size;
group_dims = MTL::Size(thread_group_size, 1, 1);
grid_dims = MTL::Size(nthreads, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[intermediates](MTL::CommandBuffer*) mutable {
intermediates.clear();
});
}
}
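To make the two-pass sizing above concrete, here is a rough C++ walkthrough of the intermediate size. The inputs (in_size = 1,000,000, REDUCE_N_READS = 16, maxTotalThreadsPerThreadgroup() = 1024) are illustrative assumptions; the simd-size rounding step is omitted because 1024 is already a multiple of 32.

#include <cstddef>
#include <cstdio>

int main() {
  size_t in_size = 1000000, n_reads = 16, max_threads = 1024;     // assumed values
  size_t mod_in_size = (in_size + n_reads - 1) / n_reads;         // 62500 groups of reads
  size_t thread_group_size =
      mod_in_size > max_threads ? max_threads : mod_in_size;      // 1024
  size_t n_thread_groups =
      (mod_in_size + thread_group_size - 1) / thread_group_size;  // 62
  // First pass: 62 threadgroups each write one partial result to the
  // intermediate array; second pass: one dispatch reduces those 62 partials.
  printf("intermediate size = %zu\n", n_thread_groups);
  return 0;
}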
void row_reduce_general_dispatch(
@@ -76,22 +126,31 @@ void row_reduce_general_dispatch(
const ReductionPlan& plan,
const std::vector<int>& axes,
MTL::ComputeCommandEncoder* compute_encoder,
metal::Device& d) {
auto kernel =
d.get_kernel("row_reduce_general_" + op_name + type_to_name(in));
metal::Device& d,
const Stream& s) {
Dtype out_dtype = out.dtype();
bool is_out_64b_int = is_64b_int(out_dtype);
auto kernel = (is_out_64b_int)
? d.get_kernel(
"row_reduce_general_no_atomics_" + op_name + type_to_name(in))
: d.get_kernel("row_reduce_general_" + op_name + type_to_name(in));
compute_encoder->setComputePipelineState(kernel);
// Prepare the arguments for the kernel
int n_reads = REDUCE_N_READS;
size_t reduction_size = plan.shape.back();
size_t out_size = out.size();
auto shape = plan.shape;
auto strides = plan.strides;
shape.pop_back();
strides.pop_back();
size_t non_row_reductions = 1;
for (auto s : shape) {
non_row_reductions *= static_cast<size_t>(s);
}
size_t out_size = out.size();
auto [rem_shape, rem_strides] = shapes_without_reduction_axes(in, axes);
for (auto s : rem_shape) {
shape.push_back(s);
@@ -101,16 +160,6 @@ void row_reduce_general_dispatch(
}
int ndim = shape.size();
// Set the arguments for the kernel
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4);
compute_encoder->setBytes(strides.data(), strides.size() * sizeof(size_t), 5);
compute_encoder->setBytes(&ndim, sizeof(int), 6);
// Each thread group is responsible for 1 output
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
thread_group_size =
@@ -127,7 +176,88 @@ void row_reduce_general_dispatch(
MTL::Size grid_dims = MTL::Size(n_threads, non_row_reductions, 1);
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
if (is_out_64b_int == false || non_row_reductions == 1) {
// Set the arguments for the kernel
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4);
compute_encoder->setBytes(
strides.data(), strides.size() * sizeof(size_t), 5);
compute_encoder->setBytes(&ndim, sizeof(int), 6);
compute_encoder->dispatchThreads(grid_dims, group_dims);
} else {
// Allocate intermediate array to store partial reduction results
array intermediate = array(
{static_cast<int>(out.size()), static_cast<int>(non_row_reductions)},
out_dtype,
nullptr,
{});
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
std::vector<array> intermediates = {intermediate};
// Set the arguments for the kernel
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, intermediate, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4);
compute_encoder->setBytes(
strides.data(), strides.size() * sizeof(size_t), 5);
compute_encoder->setBytes(&ndim, sizeof(int), 6);
compute_encoder->dispatchThreads(grid_dims, group_dims);
// Set up second dispatch
reduction_size = non_row_reductions;
out_size = 1;
// Shape of axes that aren't participating in reduction remains unchanged.
std::vector<int> new_shape = rem_shape;
// Update their strides since they'll be different after the partial
// reduction done by the first compute dispatch.
std::vector<size_t> new_strides = rem_strides;
new_strides.back() = reduction_size;
for (int i = new_shape.size() - 2; i >= 0; i--) {
new_strides[i] = new_shape[i + 1] * new_strides[i + 1];
}
ndim = new_shape.size();
// Set the arguments for the kernel
set_array_buffer(compute_encoder, intermediate, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
compute_encoder->setBytes(
new_shape.data(), new_shape.size() * sizeof(int), 4);
compute_encoder->setBytes(
new_strides.data(), new_strides.size() * sizeof(size_t), 5);
compute_encoder->setBytes(&ndim, sizeof(int), 6);
// Each thread group is responsible for 1 output
thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
thread_group_size =
std::min((reduction_size + n_reads - 1) / n_reads, thread_group_size);
// Align thread group size with simd_size
thread_group_size =
(thread_group_size + simd_size - 1) / simd_size * simd_size;
assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup());
// Launch enough thread groups for each output
n_threads = thread_group_size;
grid_dims = MTL::Size(n_threads, out.size(), 1);
group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[intermediates](MTL::CommandBuffer*) mutable {
intermediates.clear();
});
}
}
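As a worked example of the stride recomputation in the second dispatch above: assume rem_shape = {4, 5} (so out.size() == 20) and non_row_reductions = 3; both values are made up for illustration. The intermediate is then a dense (20, 3) buffer, and the recomputed strides describe it as shape {4, 5} with a trailing reduction of length 3.

#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> new_shape = {4, 5};  // rem_shape (assumed)
  size_t reduction_size = 3;            // non_row_reductions from the first pass (assumed)
  std::vector<size_t> new_strides(new_shape.size());
  new_strides.back() = reduction_size;  // each output element owns 3 contiguous partials
  for (int i = (int)new_shape.size() - 2; i >= 0; i--) {
    new_strides[i] = new_shape[i + 1] * new_strides[i + 1];
  }
  printf("%zu %zu\n", new_strides[0], new_strides[1]);  // 15 3
  return 0;
}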
void strided_reduce_general_dispatch(
@@ -137,9 +267,16 @@ void strided_reduce_general_dispatch(
const ReductionPlan& plan,
const std::vector<int>& axes,
MTL::ComputeCommandEncoder* compute_encoder,
metal::Device& d) {
auto kernel =
d.get_kernel("col_reduce_general_" + op_name + type_to_name(in));
metal::Device& d,
const Stream& s) {
Dtype out_dtype = out.dtype();
bool is_out_64b_int = is_64b_int(out_dtype);
auto kernel = (is_out_64b_int)
? d.get_kernel(
"col_reduce_general_no_atomics_" + op_name + type_to_name(in))
: d.get_kernel("col_reduce_general_" + op_name + type_to_name(in));
compute_encoder->setComputePipelineState(kernel);
// Prepare the arguments for the kernel
size_t reduction_size = plan.shape.back();
@@ -162,19 +299,7 @@ void strided_reduce_general_dispatch(
}
int ndim = shape.size();
// Set the arguments for the kernel
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
compute_encoder->setBytes(&out_size, sizeof(size_t), 4);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5);
compute_encoder->setBytes(strides.data(), strides.size() * sizeof(size_t), 6);
compute_encoder->setBytes(&ndim, sizeof(int), 7);
// Select block dimensions
// Each thread reads 16 inputs to give it more work
uint n_inputs_per_thread = REDUCE_N_READS;
uint n_threads_per_output =
@@ -183,14 +308,22 @@ void strided_reduce_general_dispatch(
// We spread outputs over the x dimension and inputs over the y dimension
// Threads with the same lid.x in a given threadgroup work on the same
// output and each thread in the y dimension accumulates for that output
// Threads with same lid.x, i.e. each column of threads work on same output
uint threadgroup_dim_x = std::min(out_size, 128ul);
// Number of threads along y depends on the number of reductions needed.
uint threadgroup_dim_y =
kernel->maxTotalThreadsPerThreadgroup() / threadgroup_dim_x;
threadgroup_dim_y = std::min(n_threads_per_output, threadgroup_dim_y);
// Derive number of thread groups along x, based on how many threads we need
// along x
uint n_threadgroups_x =
(out_size + threadgroup_dim_x - 1) / threadgroup_dim_x;
// Derive number of thread groups along y based on how many threads we need
// along y
uint n_threadgroups_y =
(n_threads_per_output + threadgroup_dim_y - 1) / threadgroup_dim_y;
@@ -199,18 +332,122 @@ void strided_reduce_general_dispatch(
MTL::Size(n_threadgroups_x, n_threadgroups_y, non_col_reductions);
MTL::Size group_dims = MTL::Size(threadgroup_dim_x, threadgroup_dim_y, 1);
// We set shared memory to be exploited here for reductions within a
// threadgroup - each thread must be able to update its accumulated output
// Note: Each threadgroup should have 32kB of data in threadgroup memory
// and threadgroup_dim_x * threadgroup_dim_y <= 1024 by design
// This should be fine for floats, but we might need to revisit
// if we ever come to doubles. In that case, we should also cut
// down the number of threads we launch in a threadgroup
compute_encoder->setThreadgroupMemoryLength(
safe_divup(threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 16),
0);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
if (is_out_64b_int == false) {
// Set the arguments for the kernel
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
compute_encoder->setBytes(&out_size, sizeof(size_t), 4);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5);
compute_encoder->setBytes(
strides.data(), strides.size() * sizeof(size_t), 6);
compute_encoder->setBytes(&ndim, sizeof(int), 7);
// We set shared memory to be exploited here for reductions within a
// threadgroup - each thread must be able to update its accumulated output
// Note: Each threadgroup should have 32kB of data in threadgroup memory
// and threadgroup_dim_x * threadgroup_dim_y <= 1024 by design
// This should be fine for floats, but we might need to revisit
// if we ever come to doubles. In that case, we should also cut
// down the number of threads we launch in a threadgroup
compute_encoder->setThreadgroupMemoryLength(
safe_divup(threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 16),
0);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
} else {
// Allocate intermediate array to store reduction results from all thread
// groups
array intermediate = array(
{static_cast<int>(out.size()),
static_cast<int>(n_threadgroups_y * non_col_reductions)},
out_dtype,
nullptr,
{});
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
std::vector<array> intermediates = {intermediate};
// Set the arguments for the kernel
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, intermediate, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
compute_encoder->setBytes(&out_size, sizeof(size_t), 4);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5);
compute_encoder->setBytes(
strides.data(), strides.size() * sizeof(size_t), 6);
compute_encoder->setBytes(&ndim, sizeof(int), 7);
// We set shared memory to be exploited here for reductions within a
// threadgroup - each thread must be able to update its accumulated output
// Note: Each threadgroup should have 32kB of data in threadgroup memory
// and threadgroup_dim_x * threadgroup_dim_y <= 1024 by design
// This should be fine for floats, but we might need to revisit
// if we ever come to doubles. In that case, we should also cut
// down the number of threads we launch in a threadgroup
compute_encoder->setThreadgroupMemoryLength(
safe_divup(threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 16),
0);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
// Perform second pass of reductions
// Reduce results of threadgroups along y, z from first pass, that
// collectively work on each output element.
reduction_size = n_threadgroups_y * non_col_reductions;
out_size = 1;
// Shape of axes that aren't participating in reduction remains unchanged.
std::vector<int> new_shape = rem_shape;
// Update their strides since they'll be different after the partial
// reduction done by the first compute dispatch.
std::vector<size_t> new_strides = rem_strides;
new_strides.back() = reduction_size;
for (int i = new_shape.size() - 2; i >= 0; i--) {
new_strides[i] = new_shape[i + 1] * new_strides[i + 1];
}
ndim = new_shape.size();
auto row_reduce_kernel = d.get_kernel(
"row_reduce_general_no_atomics_" + op_name +
type_to_name(intermediate));
compute_encoder->setComputePipelineState(row_reduce_kernel);
set_array_buffer(compute_encoder, intermediate, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
compute_encoder->setBytes(
new_shape.data(), new_shape.size() * sizeof(int), 4);
compute_encoder->setBytes(
new_strides.data(), new_strides.size() * sizeof(size_t), 5);
compute_encoder->setBytes(&ndim, sizeof(int), 6);
// Each thread group is responsible for 1 output
size_t n_reads = REDUCE_N_READS;
size_t thread_group_size =
row_reduce_kernel->maxTotalThreadsPerThreadgroup();
thread_group_size =
std::min((reduction_size + n_reads - 1) / n_reads, thread_group_size);
// Align thread group size with simd_size
uint simd_size = row_reduce_kernel->threadExecutionWidth();
thread_group_size =
(thread_group_size + simd_size - 1) / simd_size * simd_size;
assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup());
// Launch enough thread groups for each output
uint n_threads = thread_group_size;
grid_dims = MTL::Size(n_threads, out.size(), 1);
group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[intermediates](MTL::CommandBuffer*) mutable {
intermediates.clear();
});
}
}
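For a feel of the launch geometry in the no-atomics column reduction above, here is an illustrative C++ computation of the block dimensions and intermediate shape. All inputs (out_size = 300, reduction_size = 4096, REDUCE_N_READS = 16, maxTotalThreadsPerThreadgroup() = 1024, non_col_reductions = 2) are assumed values, not taken from the diff.

#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
  size_t out_size = 300, reduction_size = 4096, non_col_reductions = 2;
  size_t n_inputs_per_thread = 16, max_threads = 1024;                       // assumed values
  size_t n_threads_per_output =
      (reduction_size + n_inputs_per_thread - 1) / n_inputs_per_thread;      // 256
  size_t threadgroup_dim_x = std::min(out_size, (size_t)128);                // 128
  size_t threadgroup_dim_y =
      std::min(n_threads_per_output, max_threads / threadgroup_dim_x);       // 8
  size_t n_threadgroups_y =
      (n_threads_per_output + threadgroup_dim_y - 1) / threadgroup_dim_y;    // 32
  // Each output element ends up with n_threadgroups_y * non_col_reductions = 64
  // partials; the follow-up row_reduce_general_no_atomics pass collapses them.
  printf("intermediate shape: (%zu, %zu)\n",
         out_size, n_threadgroups_y * non_col_reductions);
  return 0;
}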
} // namespace
@@ -223,14 +460,6 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
array in = inputs[0];
// TODO: Allow specific row and column reductions with types disabled
// due to atomics ?
if (size_of(in.dtype()) == 8) {
std::ostringstream msg;
msg << "[Reduce::eval_gpu] Does not support " << in.dtype();
throw std::runtime_error(msg.str());
}
// Make sure no identity reductions trickle down here
assert(!axes_.empty());
@@ -297,7 +526,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
// Reducing over everything and the data is all there no broadcasting or
// slicing etc.
if (plan.type == ContiguousAllReduce) {
all_reduce_dispatch(in, out, op_name, compute_encoder, d);
all_reduce_dispatch(in, out, op_name, compute_encoder, d, s);
}
// At least the last dimension is row contiguous and we are reducing over
@@ -305,7 +534,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
else if (
plan.type == ContiguousReduce || plan.type == GeneralContiguousReduce) {
row_reduce_general_dispatch(
in, out, op_name, plan, axes_, compute_encoder, d);
in, out, op_name, plan, axes_, compute_encoder, d, s);
}
// At least the last two dimensions are contiguous and we are doing a
@@ -314,7 +543,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
plan.type == ContiguousStridedReduce ||
plan.type == GeneralStridedReduce) {
strided_reduce_general_dispatch(
in, out, op_name, plan, axes_, compute_encoder, d);
in, out, op_name, plan, axes_, compute_encoder, d, s);
}
if (!copies.empty()) {

View File

@@ -55,6 +55,8 @@ class TestReduce(mlx_tests.MLXTestCase):
"uint8",
"uint16",
"uint32",
"int64",
"uint64",
]
float_dtypes = ["float32"]