Mirror of https://github.com/ml-explore/mlx.git, synced 2025-07-23 18:11:17 +08:00
Add GPU support for uint64/int64 reductions (#569)
parent bad67fec37
commit fcc5ac1c64
@@ -63,18 +63,6 @@ struct ArgMax {
   }
 };
 
-bool simd_shuffle_down(bool data, uint16_t delta) {
-  return simd_shuffle_down(static_cast<uint32_t>(data), delta);
-}
-
-uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) {
-  return as_type<uint64_t>(simd_shuffle_down(as_type<uint2>(data), delta));
-}
-
-int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
-  return as_type<int64_t>(simd_shuffle_down(as_type<uint2>(data), delta));
-}
-
 template <typename U>
 IndexValPair<U> simd_shuffle_down(IndexValPair<U> data, uint16_t delta) {
   return IndexValPair<U>(
@@ -24,11 +24,59 @@ template <typename T, typename Op>
       device otype *out [[buffer(1)]], \
       uint tid [[thread_position_in_grid]]);
 
 
 ///////////////////////////////////////////////////////////////////////////////
 // All reduce
 ///////////////////////////////////////////////////////////////////////////////
 
+template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
+inline U per_thread_all_reduce(
+    const device T *in,
+    const device size_t& in_size,
+    uint gid,
+    uint grid_size) {
+  Op op;
+  U total_val = Op::init;
+
+  if (gid * N_READS < in_size) {
+    in += gid * N_READS;
+
+    int r = 0;
+    for(; r < (int)ceildiv(in_size, grid_size * N_READS) - 1; r++) {
+      U vals[N_READS] = {op.init};
+
+      for(int i = 0; i < N_READS; i++) {
+        vals[i] = static_cast<U>(in[i]);
+      }
+      for(int i = 0; i < N_READS; i++) {
+        total_val = op(vals[i], total_val);
+      }
+
+      in += grid_size * N_READS;
+    }
+
+    // Separate case for the last set as we close the reduction size
+    size_t curr_idx = (gid + r * (size_t)grid_size) * N_READS;
+    if (curr_idx < in_size) {
+      int max_reads = in_size - curr_idx;
+      T vals[N_READS];
+
+      for(int i = 0, idx = 0; i < N_READS; i++, idx++) {
+        idx = idx < max_reads ? idx : max_reads - 1;
+        vals[i] = in[idx];
+      }
+      for(int i = 0; i < N_READS; i++) {
+        U val = i < max_reads ? vals[i] : Op::init;
+        total_val = op(static_cast<U>(val), total_val);
+      }
+    }
+  }
+
+  return total_val;
+}
+
+// NB: This kernel assumes threads_per_threadgroup is at most
+// 1024. This way with a simd_size of 32, we are guaranteed to
+// complete the reduction in two steps of simd-level reductions.
 template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
 [[kernel]] void all_reduce(
     const device T *in [[buffer(0)]],
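Aside (not part of the diff): a minimal CPU-side sketch of the grid-stride access pattern that per_thread_all_reduce uses above. The helper name, the Op shape, and the sizes here are made up for illustration only; each simulated thread reads N_READS contiguous values per step, jumps ahead by grid_size * N_READS, and clamps the final partial block.

#include <cstdint>
#include <cstdio>
#include <vector>

// CPU sketch of the per-thread accumulation pattern: thread `gid` (out of
// `grid_size`) reads N_READS contiguous values per step, then strides ahead
// by grid_size * N_READS; the last step only reads what is left.
template <typename T, typename U, typename Op, int N_READS = 16>
U per_thread_accumulate(const std::vector<T>& in, U init, Op op,
                        size_t gid, size_t grid_size) {
  U total = init;
  size_t in_size = in.size();
  if (gid * N_READS >= in_size) return total;

  size_t step = grid_size * N_READS;
  size_t full_steps = (in_size + step - 1) / step - 1;
  size_t base = gid * N_READS;
  for (size_t r = 0; r < full_steps; r++) {
    for (int i = 0; i < N_READS; i++) total = op(total, static_cast<U>(in[base + i]));
    base += step;
  }
  // Tail: clamp reads to the end of the array.
  size_t curr = (gid + full_steps * grid_size) * N_READS;
  if (curr < in_size) {
    size_t max_reads = in_size - curr;
    for (size_t i = 0; i < N_READS && i < max_reads; i++)
      total = op(total, static_cast<U>(in[curr + i]));
  }
  return total;
}

int main() {
  std::vector<int64_t> data(1000, 1);
  auto add = [](int64_t a, int64_t b) { return a + b; };
  int64_t sum = 0;
  size_t grid_size = 8;  // pretend 8 "threads"
  for (size_t gid = 0; gid < grid_size; gid++)
    sum += per_thread_accumulate<int64_t, int64_t>(data, int64_t(0), add, gid, grid_size);
  std::printf("%lld\n", (long long)sum);  // prints 1000
  return 0;
}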
@@ -40,53 +88,18 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
     uint simd_per_group [[simdgroups_per_threadgroup]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
-  // NB: this kernel assumes threads_per_threadgroup is at most
-  // 1024. This way with a simd_size of 32, we are guaranteed to
-  // complete the reduction in two steps of simd-level reductions.
-
   Op op;
   threadgroup U local_vals[simd_size];
 
-  U total_val = Op::init;
-
-  in += gid * N_READS;
-
-  int r = 0;
-  for(; r < (int)ceildiv(in_size, grid_size * N_READS) - 1; r++) {
-    U vals[N_READS] = {op.init};
-
-    for(int i = 0; i < N_READS; i++) {
-      vals[i] = static_cast<U>(in[i]);
-    }
-    for(int i = 0; i < N_READS; i++) {
-      total_val = op(vals[i], total_val);
-    }
-
-    in += grid_size * N_READS;
-  }
-
-  // Separate case for the last set as we close the reduction size
-  size_t curr_idx = (gid + r * (size_t)grid_size) * N_READS;
-  if (curr_idx < in_size) {
-    int max_reads = in_size - curr_idx;
-    T vals[N_READS];
-
-    for(int i = 0, idx = 0; i < N_READS; i++, idx++) {
-      idx = idx < max_reads ? idx : max_reads - 1;
-      vals[i] = in[idx];
-    }
-    for(int i = 0; i < N_READS; i++) {
-      U val = i < max_reads ? vals[i] : Op::init;
-      total_val = op(static_cast<U>(val), total_val);
-    }
-  }
+  U total_val = per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);
 
   // Reduction within simd group
   total_val = op.simd_reduce(total_val);
   if (simd_lane_id == 0) {
     local_vals[simd_group_id] = total_val;
   }
 
   // Reduction within thread group
   threadgroup_barrier(mem_flags::mem_threadgroup);
   total_val = lid < simd_per_group ? local_vals[lid] : op.init;
@@ -98,6 +111,46 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
   }
 }
 
+template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
+[[kernel]] void all_reduce_no_atomics(
+    const device T *in [[buffer(0)]],
+    device U *out [[buffer(1)]],
+    const device size_t& in_size [[buffer(2)]],
+    uint gid [[thread_position_in_grid]],
+    uint lid [[thread_position_in_threadgroup]],
+    uint grid_size [[threads_per_grid]],
+    uint simd_per_group [[simdgroups_per_threadgroup]],
+    uint simd_lane_id [[thread_index_in_simdgroup]],
+    uint simd_group_id [[simdgroup_index_in_threadgroup]],
+    uint thread_group_id [[threadgroup_position_in_grid]]) {
+
+  Op op;
+  threadgroup U local_vals[simd_size];
+
+  U total_val = per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);
+
+  // Reduction within simd group (simd_add isn't supported for uint64/int64 types)
+  for (uint16_t lane_offset = simd_size/2; lane_offset > 0; lane_offset /= 2) {
+    total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
+  }
+  // Write simd group reduction results to local memory
+  if (simd_lane_id == 0) {
+    local_vals[simd_group_id] = total_val;
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+  // Reduction of simdgroup reduction results within threadgroup.
+  total_val = lid < simd_per_group ? local_vals[lid] : op.init;
+  for (uint16_t lane_offset = simd_size/2; lane_offset > 0; lane_offset /= 2) {
+    total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
+  }
+
+  // Reduction across threadgroups
+  if (lid == 0) {
+    out[thread_group_id] = total_val;
+  }
+}
+
 #define instantiate_all_reduce(name, itype, otype, op) \
   template [[host_name("all_reduce_" #name)]] \
   [[kernel]] void all_reduce<itype, otype, op>( \
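Aside (not part of the diff): the lane_offset loop above is a shuffle-down butterfly, used because Metal's simd_sum/simd_max builtins do not cover 64-bit types. A rough CPU simulation, with lanes modeled as an array, shows why log2(simd_size) halving steps leave the full result in lane 0.

#include <cstdint>
#include <cstdio>

// CPU sketch of the shuffle-down "butterfly": with 32 lanes, 5 halving steps
// leave the complete reduction in lane 0. `lanes` stands in for a simdgroup's
// registers; the real kernel uses simd_shuffle_down instead of array indexing.
int main() {
  const int simd_size = 32;
  int64_t lanes[simd_size];
  for (int i = 0; i < simd_size; i++) lanes[i] = i + 1;  // expected sum: 528

  for (int offset = simd_size / 2; offset > 0; offset /= 2) {
    // Each lane combines its value with the lane `offset` positions higher;
    // lanes whose partner would fall off the end are never read out anyway.
    for (int lane = 0; lane + offset < simd_size; lane++) {
      lanes[lane] += lanes[lane + offset];
    }
  }
  std::printf("lane 0 holds %lld\n", (long long)lanes[0]);  // 528
  return 0;
}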
@@ -111,11 +164,80 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
       uint simd_lane_id [[thread_index_in_simdgroup]], \
       uint simd_group_id [[simdgroup_index_in_threadgroup]]);
 
+#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \
+  template [[host_name("all_reduce_no_atomics_" #name)]] \
+  [[kernel]] void all_reduce_no_atomics<itype, otype, op>( \
+      const device itype *in [[buffer(0)]], \
+      device otype *out [[buffer(1)]], \
+      const device size_t& in_size [[buffer(2)]], \
+      uint gid [[thread_position_in_grid]], \
+      uint lid [[thread_position_in_threadgroup]], \
+      uint grid_size [[threads_per_grid]], \
+      uint simd_per_group [[simdgroups_per_threadgroup]], \
+      uint simd_lane_id [[thread_index_in_simdgroup]], \
+      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
+      uint thread_group_id [[threadgroup_position_in_grid]]);
+
 ///////////////////////////////////////////////////////////////////////////////
 // Row atomics
 ///////////////////////////////////////////////////////////////////////////////
 
+template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
+inline U per_thread_row_reduce(
+    const device T *in,
+    const constant size_t& reduction_size,
+    const constant size_t& out_size,
+    const constant int* shape,
+    const constant size_t* strides,
+    const constant int& ndim,
+    uint lsize_x,
+    uint lid_x,
+    uint2 tid) {
+
+  Op op;
+
+  // Each threadgroup handles 1 reduction
+  // TODO: Specializing elem_to_loc would be slightly faster
+  int idx = tid.y * out_size + tid.x;
+  int extra_offset = elem_to_loc(idx, shape, strides, ndim);
+  in += extra_offset + lid_x * N_READS;
+
+  // The reduction is accumulated here
+  U total_val = Op::init;
+
+  // Loop over the reduction size within thread group
+  int r = 0;
+  for (; r < (int)ceildiv(reduction_size, N_READS*lsize_x) - 1; r++) {
+    T vals[N_READS];
+    for(int i = 0; i < N_READS; i++) {
+      vals[i] = in[i];
+    }
+    for(int i = 0; i < N_READS; i++) {
+      total_val = op(static_cast<U>(vals[i]), total_val);
+    }
+
+    in += lsize_x * N_READS;
+  }
+
+  // Separate case for the last set as we close the reduction size
+  size_t reduction_index = (lid_x + (size_t)lsize_x * r) * N_READS;
+  if(reduction_index < reduction_size) {
+    int max_reads = reduction_size - reduction_index;
+
+    T vals[N_READS];
+    for(int i = 0; i < N_READS; i++) {
+      int idx = min(i, max_reads - 1);
+      vals[i] = static_cast<U>(in[idx]);
+    }
+    for(int i = 0; i < N_READS; i++) {
+      T val = i < max_reads ? vals[i] : Op::init;
+      total_val = op(static_cast<U>(val), total_val);
+    }
+  }
+
+  return total_val;
+}
+
 template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
 [[kernel]] void row_reduce_general(
     const device T *in [[buffer(0)]],
@@ -133,46 +255,9 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
 
   Op op;
-
-  // Each threadgroup handles 1 reduction
-  // TODO: Specializing elem_to_loc would be slightly faster
-  int idx = tid.y * out_size + tid.x;
-  int extra_offset = elem_to_loc(idx, shape, strides, ndim);
-  in += extra_offset + lid.x * N_READS;
-
-  // The reduction is accumulated here
-  U total_val = Op::init;
   threadgroup U local_vals[simd_size];
 
-  // Loop over the reduction size within thread group
-  int r = 0;
-  for (; r < (int)ceildiv(reduction_size, N_READS*lsize.x) - 1; r++) {
-    T vals[N_READS];
-    for(int i = 0; i < N_READS; i++) {
-      vals[i] = in[i];
-    }
-    for(int i = 0; i < N_READS; i++) {
-      total_val = op(static_cast<U>(vals[i]), total_val);
-    }
-
-    in += lsize.x * N_READS;
-  }
-
-  // Separate case for the last set as we close the reduction size
-  size_t reduction_index = (lid.x + (size_t)lsize.x * r) * N_READS;
-  if(reduction_index < reduction_size) {
-    int max_reads = reduction_size - reduction_index;
-
-    T vals[N_READS];
-    for(int i = 0; i < N_READS; i++) {
-      int idx = min(i, max_reads - 1);
-      vals[i] = static_cast<U>(in[idx]);
-    }
-    for(int i = 0; i < N_READS; i++) {
-      T val = i < max_reads ? vals[i] : Op::init;
-      total_val = op(static_cast<U>(val), total_val);
-    }
-  }
+  U total_val = per_thread_row_reduce<T, U, Op, N_READS>(in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy);
 
   total_val = op.simd_reduce(total_val);
 
@@ -194,6 +279,53 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
   }
 }
 
+template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
+[[kernel]] void row_reduce_general_no_atomics(
+    const device T *in [[buffer(0)]],
+    device U *out [[buffer(1)]],
+    const constant size_t& reduction_size [[buffer(2)]],
+    const constant size_t& out_size [[buffer(3)]],
+    const constant int* shape [[buffer(4)]],
+    const constant size_t* strides [[buffer(5)]],
+    const constant int& ndim [[buffer(6)]],
+    uint3 lid [[thread_position_in_threadgroup]],
+    uint3 lsize [[threads_per_threadgroup]],
+    uint3 gsize [[threads_per_grid]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint simd_lane_id [[thread_index_in_simdgroup]],
+    uint simd_per_group [[simdgroups_per_threadgroup]],
+    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
+
+  Op op;
+
+  threadgroup U local_vals[simd_size];
+  U total_val = per_thread_row_reduce<T, U, Op, N_READS>(in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy);
+
+  // Reduction within simd group - simd_add isn't supported for int64 types
+  for (uint16_t i = simd_size/2; i > 0; i /= 2) {
+    total_val = op(total_val, simd_shuffle_down(total_val, i));
+  }
+
+  // Prepare next level
+  if (simd_lane_id == 0) {
+    local_vals[simd_group_id] = total_val;
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+  // Reduction within thread group
+  // Only needed if thread group has multiple simd groups
+  if(ceildiv(reduction_size, N_READS) > simd_size) {
+    total_val = lid.x < simd_per_group ? local_vals[lid.x] : op.init;
+    for (uint16_t i = simd_size/2; i > 0; i /= 2) {
+      total_val = op(total_val, simd_shuffle_down(total_val, i));
+    }
+  }
+  // Write row reduce output for threadgroup with 1st thread in thread group
+  if (lid.x == 0) {
+    out[(ceildiv(gsize.y, lsize.y) * tid.x) + tid.y] = total_val;
+  }
+}
+
 #define instantiate_row_reduce_general(name, itype, otype, op) \
   template [[host_name("row_reduce_general_" #name)]] \
   [[kernel]] void row_reduce_general<itype, otype, op>( \
@@ -211,52 +343,59 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
       uint simd_per_group [[simdgroups_per_threadgroup]], \
       uint simd_group_id [[simdgroup_index_in_threadgroup]]);
 
+#define instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \
+  template [[host_name("row_reduce_general_no_atomics_" #name)]] \
+  [[kernel]] void row_reduce_general_no_atomics<itype, otype, op>( \
+      const device itype *in [[buffer(0)]], \
+      device otype *out [[buffer(1)]], \
+      const constant size_t& reduction_size [[buffer(2)]], \
+      const constant size_t& out_size [[buffer(3)]], \
+      const constant int* shape [[buffer(4)]], \
+      const constant size_t* strides [[buffer(5)]], \
+      const constant int& ndim [[buffer(6)]], \
+      uint3 lid [[thread_position_in_threadgroup]], \
+      uint3 lsize [[threads_per_threadgroup]], \
+      uint3 gsize [[threads_per_grid]], \
+      uint3 tid [[threadgroup_position_in_grid]], \
+      uint simd_lane_id [[thread_index_in_simdgroup]], \
+      uint simd_per_group [[simdgroups_per_threadgroup]], \
+      uint simd_group_id [[simdgroup_index_in_threadgroup]]);
+
 ///////////////////////////////////////////////////////////////////////////////
 // Column reduce
 ///////////////////////////////////////////////////////////////////////////////
 
 template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
-inline void _contiguous_strided_reduce(
+inline U _contiguous_strided_reduce(
     const device T *in,
-    device mlx_atomic<U> *out,
     threadgroup U *local_data,
     uint in_idx,
-    uint out_idx,
     uint reduction_size,
     uint reduction_stride,
     uint2 tid,
     uint2 lid,
     uint2 lsize) {
 
   Op op;
-  T local_vals[N_READS];
+  U total_val = Op::init;
 
   uint base_offset = (tid.y * lsize.y + lid.y) * N_READS;
 
-  for(uint r = 0; r < N_READS; r++) {
-    uint offset = base_offset + r;
-    offset = offset < reduction_size ? offset : reduction_size - 1;
-    local_vals[r] = in[in_idx + offset * reduction_stride];
-  }
-
-  U total_val = Op::init;
   for(uint r = 0; r < N_READS && (base_offset + r) < reduction_size; r++) {
-    total_val = op(static_cast<U>(total_val), local_vals[r]);
+    uint offset = base_offset + r;
+    total_val = op(static_cast<U>(total_val), in[in_idx + offset * reduction_stride]);
   }
   local_data[lsize.y * lid.x + lid.y] = total_val;
 
   threadgroup_barrier(mem_flags::mem_threadgroup);
 
+  U val = Op::init;
   if(lid.y == 0) {
-    U val = op.init;
+    // Perform reduction across columns in thread group
     for(uint i = 0; i < lsize.y; i++) {
       val = op(val, local_data[lsize.y * lid.x + i]);
     }
-
-    op.atomic_update(out, val, out_idx);
   }
 
+  return val;
 }
 
 template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
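Aside (not part of the diff): _contiguous_strided_reduce now returns the per-column partial instead of performing the atomic update itself, which is what lets the no-atomics caller redirect the result to an intermediate buffer. A small CPU sketch of the threadgroup tile it builds; the sizes and names are illustrative only.

#include <cstdio>
#include <vector>

// CPU sketch of the threadgroup tile: a (dim_x x dim_y) tile of partials is
// stored flat as local_data[dim_y * x + y], and the y == 0 thread of each
// column folds its column into one value per output.
int main() {
  const int dim_x = 4;   // outputs handled by this "threadgroup"
  const int dim_y = 8;   // threads cooperating on each output
  std::vector<long> local_data(dim_x * dim_y);

  // Pretend each thread accumulated a partial equal to its y index + 1.
  for (int x = 0; x < dim_x; x++)
    for (int y = 0; y < dim_y; y++)
      local_data[dim_y * x + y] = y + 1;

  // The "lid.y == 0" thread of each column folds that column of partials.
  for (int x = 0; x < dim_x; x++) {
    long val = 0;
    for (int y = 0; y < dim_y; y++) val += local_data[dim_y * x + y];
    std::printf("output %d partial = %ld\n", x, val);  // 36 for each column
  }
  return 0;
}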
@@ -265,13 +404,13 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
     device mlx_atomic<U> *out [[buffer(1)]],
     const constant size_t& reduction_size [[buffer(2)]],
     const constant size_t& reduction_stride [[buffer(3)]],
     const constant size_t& out_size [[buffer(4)]],
     const constant int* shape [[buffer(5)]],
     const constant size_t* strides [[buffer(6)]],
     const constant int& ndim [[buffer(7)]],
     threadgroup U *local_data [[threadgroup(0)]],
     uint3 tid [[threadgroup_position_in_grid]],
     uint3 lid [[thread_position_in_threadgroup]],
     uint3 lsize [[threads_per_threadgroup]]) {
   auto out_idx = tid.x * lsize.x + lid.x;
   auto in_idx = elem_to_loc(
@@ -281,18 +420,66 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
       ndim
   );
 
+  Op op;
   if(out_idx < out_size) {
-    _contiguous_strided_reduce<T, U, Op, N_READS>(
+    U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
         in,
-        out,
         local_data,
         in_idx,
-        out_idx,
         reduction_size,
         reduction_stride,
         tid.xy,
         lid.xy,
         lsize.xy);
+
+    // Write out reduction results generated by threadgroups working on specific output element, contiguously.
+    if (lid.y == 0) {
+      op.atomic_update(out, val, out_idx);
+    }
+  }
+}
+
+template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
+[[kernel]] void col_reduce_general_no_atomics(
+    const device T *in [[buffer(0)]],
+    device U *out [[buffer(1)]],
+    const constant size_t& reduction_size [[buffer(2)]],
+    const constant size_t& reduction_stride [[buffer(3)]],
+    const constant size_t& out_size [[buffer(4)]],
+    const constant int* shape [[buffer(5)]],
+    const constant size_t* strides [[buffer(6)]],
+    const constant int& ndim [[buffer(7)]],
+    threadgroup U *local_data [[threadgroup(0)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint3 lid [[thread_position_in_threadgroup]],
+    uint3 gid [[thread_position_in_grid]],
+    uint3 lsize [[threads_per_threadgroup]],
+    uint3 gsize [[threads_per_grid]]) {
+  auto out_idx = tid.x * lsize.x + lid.x;
+  auto in_idx = elem_to_loc(
+      out_idx + tid.z * out_size,
+      shape,
+      strides,
+      ndim
+  );
+
+  if(out_idx < out_size) {
+    U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
+        in,
+        local_data,
+        in_idx,
+        reduction_size,
+        reduction_stride,
+        tid.xy,
+        lid.xy,
+        lsize.xy);
+
+    // Write out reduction results generated by threadgroups working on specific output element, contiguously.
+    if (lid.y == 0) {
+      uint tgsize_y = ceildiv(gsize.y, lsize.y);
+      uint tgsize_z = ceildiv(gsize.z, lsize.z);
+      out[tgsize_y * tgsize_z * gid.x + tgsize_y * tid.z + tid.y] = val;
+    }
   }
 }
@@ -312,6 +499,23 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
       uint3 lid [[thread_position_in_threadgroup]], \
       uint3 lsize [[threads_per_threadgroup]]);
 
+#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \
+  template [[host_name("col_reduce_general_no_atomics_" #name)]] \
+  [[kernel]] void col_reduce_general_no_atomics<itype, otype, op>( \
+      const device itype *in [[buffer(0)]], \
+      device otype *out [[buffer(1)]], \
+      const constant size_t& reduction_size [[buffer(2)]], \
+      const constant size_t& reduction_stride [[buffer(3)]], \
+      const constant size_t& out_size [[buffer(4)]], \
+      const constant int* shape [[buffer(5)]], \
+      const constant size_t* strides [[buffer(6)]], \
+      const constant int& ndim [[buffer(7)]], \
+      threadgroup otype *local_data [[threadgroup(0)]], \
+      uint3 tid [[threadgroup_position_in_grid]], \
+      uint3 lid [[thread_position_in_threadgroup]], \
+      uint3 gid [[thread_position_in_grid]], \
+      uint3 lsize [[threads_per_threadgroup]], \
+      uint3 gsize [[threads_per_grid]]);
+
 ///////////////////////////////////////////////////////////////////////////////
 // Instantiations
@@ -322,6 +526,15 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
   instantiate_row_reduce_general(name, itype, otype, op) \
   instantiate_col_reduce_general(name, itype, otype, op)
 
+#define instantiate_reduce_no_atomics(name, itype, otype, op) \
+  instantiate_all_reduce_no_atomics(name, itype, otype, op) \
+  instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \
+  instantiate_col_reduce_general_no_atomics(name, itype, otype, op)
+
+#define instantiate_same_reduce_no_atomics(name, tname, type, op) \
+  instantiate_init_reduce(name ##tname, type, op<type>) \
+  instantiate_reduce_no_atomics(name ##tname, type, type, op<type>)
+
 #define instantiate_same_reduce(name, tname, type, op) \
   instantiate_init_reduce(name ##tname, type, op<type>) \
   instantiate_reduce(name ##tname, type, type, op<type>)
@@ -353,6 +566,9 @@ instantiate_same_reduce(sum, int32, int32_t, Sum)
 instantiate_same_reduce(sum, float16, half, Sum)
 instantiate_same_reduce(sum, float32, float, Sum)
 
+instantiate_same_reduce_no_atomics(sum, int64, int64_t, Sum)
+instantiate_same_reduce_no_atomics(sum, uint64, uint64_t, Sum)
+
 instantiate_same_reduce(prod, uint8, uint8_t, Prod)
 instantiate_same_reduce(prod, uint16, uint16_t, Prod)
 instantiate_same_reduce(prod, uint32, uint32_t, Prod)
@@ -362,6 +578,9 @@ instantiate_same_reduce(prod, int32, int32_t, Prod)
 instantiate_same_reduce(prod, float16, half, Prod)
 instantiate_same_reduce(prod, float32, float, Prod)
 
+instantiate_same_reduce_no_atomics(prod, int64, int64_t, Prod)
+instantiate_same_reduce_no_atomics(prod, uint64, uint64_t, Prod)
+
 instantiate_same_reduce(sum, bfloat16, bfloat16_t, Sum)
 instantiate_same_reduce(prod, bfloat16, bfloat16_t, Prod)
@@ -381,6 +600,9 @@ instantiate_same_reduce(min_, int32, int32_t, Min)
 instantiate_same_reduce(min_, float16, half, Min)
 instantiate_same_reduce(min_, float32, float, Min)
 
+instantiate_same_reduce_no_atomics(min_, int64, int64_t, Min)
+instantiate_same_reduce_no_atomics(min_, uint64, uint64_t, Min)
+
 instantiate_same_reduce(max_, uint8, uint8_t, Max)
 instantiate_same_reduce(max_, uint16, uint16_t, Max)
 instantiate_same_reduce(max_, uint32, uint32_t, Max)
@@ -390,5 +612,8 @@ instantiate_same_reduce(max_, int32, int32_t, Max)
 instantiate_same_reduce(max_, float16, half, Max)
 instantiate_same_reduce(max_, float32, float, Max)
 
+instantiate_same_reduce_no_atomics(max_, int64, int64_t, Max)
+instantiate_same_reduce_no_atomics(max_, uint64, uint64_t, Max)
+
 instantiate_same_reduce(min_, bfloat16, bfloat16_t, Min)
 instantiate_same_reduce(max_, bfloat16, bfloat16_t, Max)
@@ -256,3 +256,21 @@ inline bfloat16_t log1p(bfloat16_t x) {
   return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
 }
 
+///////////////////////////////////////////////////////////////////////////////
+// SIMD shuffle ops
+///////////////////////////////////////////////////////////////////////////////
+
+inline uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) {
+  return as_type<uint64_t>(
+      metal::simd_shuffle_down(as_type<uint2>(data), delta));
+}
+
+inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
+  return as_type<int64_t>(
+      metal::simd_shuffle_down(as_type<uint2>(data), delta));
+}
+
+inline bool simd_shuffle_down(bool data, uint16_t delta) {
+  return simd_shuffle_down(static_cast<uint32_t>(data), delta);
+}
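Aside (not part of the diff): the overloads above view a 64-bit value as a uint2 (two 32-bit words), shuffle that vector, and view the result as 64-bit again, since the Metal shuffle builtins operate on 32-bit components. A CPU sketch of the reinterpretation, using memcpy in place of as_type:

#include <cstdint>
#include <cstdio>
#include <cstring>

// CPU sketch of the as_type<uint2> trick: a 64-bit value is split into two
// 32-bit words, the pair is moved as a unit (on the GPU, by
// metal::simd_shuffle_down on a uint2), and then viewed as 64-bit again.
struct uint2_sim { uint32_t x, y; };

uint2_sim as_uint2(uint64_t v) {
  uint2_sim out;
  std::memcpy(&out, &v, sizeof(out));
  return out;
}

uint64_t as_uint64(uint2_sim v) {
  uint64_t out;
  std::memcpy(&out, &v, sizeof(out));
  return out;
}

int main() {
  uint64_t value = 0x1234567890abcdefULL;
  uint2_sim halves = as_uint2(value);      // split into two 32-bit lanes
  uint2_sim moved = {halves.x, halves.y};  // stand-in for the shuffle itself
  uint64_t round_trip = as_uint64(moved);
  std::printf("%llx\n", (unsigned long long)round_trip);  // 1234567890abcdef
  return 0;
}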
@@ -28,35 +28,40 @@ inline auto safe_divup(size_t n, size_t m) {
   return safe_div(n, m) * m;
 }
 
+inline bool is_64b_int(Dtype dtype) {
+  return dtype == int64 || dtype == uint64;
+}
+
 // All Reduce
 void all_reduce_dispatch(
     const array& in,
     array& out,
     const std::string& op_name,
     MTL::ComputeCommandEncoder* compute_encoder,
-    metal::Device& d) {
-  // Get kernel and encode buffers
-  size_t in_size = in.size();
-  auto kernel = d.get_kernel("all_reduce_" + op_name + type_to_name(in));
+    metal::Device& d,
+    const Stream& s) {
+  Dtype out_dtype = out.dtype();
+  bool is_out_64b_int = is_64b_int(out_dtype);
+  auto kernel = (is_out_64b_int)
+      ? d.get_kernel("all_reduce_no_atomics_" + op_name + type_to_name(in))
+      : d.get_kernel("all_reduce_" + op_name + type_to_name(in));
+
   compute_encoder->setComputePipelineState(kernel);
-  set_array_buffer(compute_encoder, in, 0);
-  set_array_buffer(compute_encoder, out, 1);
-  compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
-
-  // Set grid dimensions
 
   // We make sure each thread has enough to do by making it read in
   // at least n_reads inputs
   int n_reads = REDUCE_N_READS;
+  size_t in_size = in.size();
 
   // mod_in_size gives us the groups of n_reads needed to go over the entire
   // input
   uint mod_in_size = (in_size + n_reads - 1) / n_reads;
 
   NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
   thread_group_size =
       mod_in_size > thread_group_size ? thread_group_size : mod_in_size;
+  uint simd_size = kernel->threadExecutionWidth();
+  thread_group_size =
+      ((thread_group_size + simd_size - 1) / simd_size) * simd_size;
 
   // If the number of thread groups needed exceeds 1024, we reuse threads groups
   uint n_thread_groups = safe_div(mod_in_size, thread_group_size);
@@ -66,7 +71,52 @@ void all_reduce_dispatch(
   MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
   MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
 
-  compute_encoder->dispatchThreads(grid_dims, group_dims);
+  // Encode buffers and dispatch
+  if (is_out_64b_int == false || n_thread_groups == 1) {
+    set_array_buffer(compute_encoder, in, 0);
+    set_array_buffer(compute_encoder, out, 1);
+    compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
+    compute_encoder->dispatchThreads(grid_dims, group_dims);
+
+  } else {
+    // Allocate intermediate array to store partial reduction results
+    size_t intermediate_size = n_thread_groups;
+    array intermediate =
+        array({static_cast<int>(intermediate_size)}, out_dtype, nullptr, {});
+    intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
+    std::vector<array> intermediates = {intermediate};
+
+    // First dispatch
+    set_array_buffer(compute_encoder, in, 0);
+    set_array_buffer(compute_encoder, intermediate, 1);
+    compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
+    compute_encoder->dispatchThreads(grid_dims, group_dims);
+
+    // Second pass to reduce intermediate reduction results written to DRAM
+    set_array_buffer(compute_encoder, intermediate, 0);
+    set_array_buffer(compute_encoder, out, 1);
+    compute_encoder->setBytes(&intermediate_size, sizeof(size_t), 2);
+
+    mod_in_size = (intermediate_size + n_reads - 1) / n_reads;
+
+    thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
+    thread_group_size =
+        mod_in_size > thread_group_size ? thread_group_size : mod_in_size;
+    thread_group_size =
+        ((thread_group_size + simd_size - 1) / simd_size) * simd_size;
+
+    // If the number of thread groups needed exceeds 1024, we reuse threads
+    // groups
+    nthreads = thread_group_size;
+    group_dims = MTL::Size(thread_group_size, 1, 1);
+    grid_dims = MTL::Size(nthreads, 1, 1);
+    compute_encoder->dispatchThreads(grid_dims, group_dims);
+
+    d.get_command_buffer(s.index)->addCompletedHandler(
+        [intermediates](MTL::CommandBuffer*) mutable {
+          intermediates.clear();
+        });
+  }
 }
 
 void row_reduce_general_dispatch(
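Aside (not part of the diff): the branch above implements a two-dispatch reduction for 64-bit outputs. A CPU sketch of the idea, with hypothetical helper names: pass 1 writes one partial per threadgroup into an intermediate buffer, and pass 2 reduces those partials with a single group, so no atomic cross-group accumulation is ever needed.

#include <cstdint>
#include <cstdio>
#include <vector>

// CPU sketch of the two-dispatch strategy: reduce each slice into an
// intermediate buffer (pass 1), then reduce the intermediates (pass 2).
int64_t reduce_chunk(const std::vector<int64_t>& v, size_t begin, size_t end) {
  int64_t acc = 0;
  for (size_t i = begin; i < end && i < v.size(); i++) acc += v[i];
  return acc;
}

int main() {
  std::vector<int64_t> in(10007, 1);
  size_t n_thread_groups = 13;
  size_t chunk = (in.size() + n_thread_groups - 1) / n_thread_groups;

  // Pass 1: each "threadgroup" reduces its slice into the intermediate buffer.
  std::vector<int64_t> intermediate(n_thread_groups);
  for (size_t g = 0; g < n_thread_groups; g++)
    intermediate[g] = reduce_chunk(in, g * chunk, (g + 1) * chunk);

  // Pass 2: a single group reduces the partials into the final output.
  int64_t out = reduce_chunk(intermediate, 0, intermediate.size());
  std::printf("%lld\n", (long long)out);  // prints 10007
  return 0;
}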
@@ -76,22 +126,31 @@ void row_reduce_general_dispatch(
     const ReductionPlan& plan,
     const std::vector<int>& axes,
     MTL::ComputeCommandEncoder* compute_encoder,
-    metal::Device& d) {
-  auto kernel =
-      d.get_kernel("row_reduce_general_" + op_name + type_to_name(in));
+    metal::Device& d,
+    const Stream& s) {
+  Dtype out_dtype = out.dtype();
+  bool is_out_64b_int = is_64b_int(out_dtype);
+  auto kernel = (is_out_64b_int)
+      ? d.get_kernel(
+            "row_reduce_general_no_atomics_" + op_name + type_to_name(in))
+      : d.get_kernel("row_reduce_general_" + op_name + type_to_name(in));
+
+  compute_encoder->setComputePipelineState(kernel);
 
   // Prepare the arguments for the kernel
   int n_reads = REDUCE_N_READS;
   size_t reduction_size = plan.shape.back();
-  size_t out_size = out.size();
   auto shape = plan.shape;
   auto strides = plan.strides;
 
   shape.pop_back();
   strides.pop_back();
 
   size_t non_row_reductions = 1;
   for (auto s : shape) {
     non_row_reductions *= static_cast<size_t>(s);
   }
+  size_t out_size = out.size();
   auto [rem_shape, rem_strides] = shapes_without_reduction_axes(in, axes);
   for (auto s : rem_shape) {
     shape.push_back(s);
@@ -101,16 +160,6 @@ void row_reduce_general_dispatch(
   }
   int ndim = shape.size();
 
-  // Set the arguments for the kernel
-  compute_encoder->setComputePipelineState(kernel);
-  set_array_buffer(compute_encoder, in, 0);
-  set_array_buffer(compute_encoder, out, 1);
-  compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
-  compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
-  compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4);
-  compute_encoder->setBytes(strides.data(), strides.size() * sizeof(size_t), 5);
-  compute_encoder->setBytes(&ndim, sizeof(int), 6);
-
   // Each thread group is responsible for 1 output
   NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
   thread_group_size =
|
|||||||
MTL::Size grid_dims = MTL::Size(n_threads, non_row_reductions, 1);
|
MTL::Size grid_dims = MTL::Size(n_threads, non_row_reductions, 1);
|
||||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||||
|
|
||||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
if (is_out_64b_int == false || non_row_reductions == 1) {
|
||||||
|
// Set the arguments for the kernel
|
||||||
|
set_array_buffer(compute_encoder, in, 0);
|
||||||
|
set_array_buffer(compute_encoder, out, 1);
|
||||||
|
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
|
||||||
|
compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
|
||||||
|
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4);
|
||||||
|
compute_encoder->setBytes(
|
||||||
|
strides.data(), strides.size() * sizeof(size_t), 5);
|
||||||
|
compute_encoder->setBytes(&ndim, sizeof(int), 6);
|
||||||
|
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||||
|
|
||||||
|
} else {
|
||||||
|
// Allocate intermediate array to store partial reduction results
|
||||||
|
array intermediate = array(
|
||||||
|
{static_cast<int>(out.size()), static_cast<int>(non_row_reductions)},
|
||||||
|
out_dtype,
|
||||||
|
nullptr,
|
||||||
|
{});
|
||||||
|
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
|
||||||
|
std::vector<array> intermediates = {intermediate};
|
||||||
|
|
||||||
|
// Set the arguments for the kernel
|
||||||
|
set_array_buffer(compute_encoder, in, 0);
|
||||||
|
set_array_buffer(compute_encoder, intermediate, 1);
|
||||||
|
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
|
||||||
|
compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
|
||||||
|
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4);
|
||||||
|
compute_encoder->setBytes(
|
||||||
|
strides.data(), strides.size() * sizeof(size_t), 5);
|
||||||
|
compute_encoder->setBytes(&ndim, sizeof(int), 6);
|
||||||
|
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||||
|
|
||||||
|
// Set up second dispatch
|
||||||
|
reduction_size = non_row_reductions;
|
||||||
|
out_size = 1;
|
||||||
|
|
||||||
|
// Shape of axes that aren't participating in reduction remains unchanged.
|
||||||
|
std::vector<int> new_shape = rem_shape;
|
||||||
|
|
||||||
|
// Update their strides since they'll be different post partial reduction in
|
||||||
|
// first compute dispatch.
|
||||||
|
std::vector<size_t> new_strides = rem_strides;
|
||||||
|
new_strides.back() = reduction_size;
|
||||||
|
for (int i = new_shape.size() - 2; i >= 0; i--) {
|
||||||
|
new_strides[i] = new_shape[i + 1] * new_strides[i + 1];
|
||||||
|
}
|
||||||
|
ndim = new_shape.size();
|
||||||
|
|
||||||
|
// Set the arguments for the kernel
|
||||||
|
set_array_buffer(compute_encoder, intermediate, 0);
|
||||||
|
set_array_buffer(compute_encoder, out, 1);
|
||||||
|
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
|
||||||
|
compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
|
||||||
|
compute_encoder->setBytes(
|
||||||
|
new_shape.data(), new_shape.size() * sizeof(int), 4);
|
||||||
|
compute_encoder->setBytes(
|
||||||
|
new_strides.data(), new_strides.size() * sizeof(size_t), 5);
|
||||||
|
compute_encoder->setBytes(&ndim, sizeof(int), 6);
|
||||||
|
|
||||||
|
// Each thread group is responsible for 1 output
|
||||||
|
thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||||
|
thread_group_size =
|
||||||
|
std::min((reduction_size + n_reads - 1) / n_reads, thread_group_size);
|
||||||
|
|
||||||
|
// Align thread group size with simd_size
|
||||||
|
thread_group_size =
|
||||||
|
(thread_group_size + simd_size - 1) / simd_size * simd_size;
|
||||||
|
assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup());
|
||||||
|
|
||||||
|
// Launch enough thread groups for each output
|
||||||
|
n_threads = thread_group_size;
|
||||||
|
grid_dims = MTL::Size(n_threads, out.size(), 1);
|
||||||
|
group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||||
|
|
||||||
|
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||||
|
|
||||||
|
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||||
|
[intermediates](MTL::CommandBuffer*) mutable {
|
||||||
|
intermediates.clear();
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void strided_reduce_general_dispatch(
|
void strided_reduce_general_dispatch(
|
||||||
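Aside (not part of the diff): the second dispatch above treats the intermediate buffer as a row reduction over non_row_reductions partials per output element, so the non-reduced axes keep their shape but need recomputed row-major strides. A CPU sketch of that stride computation; the function name is illustrative, not MLX's API.

#include <cstdio>
#include <vector>

// CPU sketch of the stride recomputation: after the first dispatch, each
// output element owns `partials_per_output` contiguous values in the
// intermediate buffer, so the remaining axes get fresh row-major strides
// with that factor as the innermost stride.
std::vector<size_t> strides_for_second_pass(
    const std::vector<int>& rem_shape, size_t partials_per_output) {
  std::vector<size_t> strides(rem_shape.size());
  strides.back() = partials_per_output;
  for (int i = static_cast<int>(rem_shape.size()) - 2; i >= 0; i--) {
    strides[i] = rem_shape[i + 1] * strides[i + 1];
  }
  return strides;
}

int main() {
  // e.g. an output of shape (2, 3) with 4 partials per output element.
  auto s = strides_for_second_pass({2, 3}, 4);
  std::printf("%zu %zu\n", s[0], s[1]);  // prints "12 4"
  return 0;
}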
@@ -137,9 +267,16 @@ void strided_reduce_general_dispatch(
     const ReductionPlan& plan,
     const std::vector<int>& axes,
     MTL::ComputeCommandEncoder* compute_encoder,
-    metal::Device& d) {
-  auto kernel =
-      d.get_kernel("col_reduce_general_" + op_name + type_to_name(in));
+    metal::Device& d,
+    const Stream& s) {
+  Dtype out_dtype = out.dtype();
+  bool is_out_64b_int = is_64b_int(out_dtype);
+  auto kernel = (is_out_64b_int)
+      ? d.get_kernel(
+            "col_reduce_general_no_atomics_" + op_name + type_to_name(in))
+      : d.get_kernel("col_reduce_general_" + op_name + type_to_name(in));
+
+  compute_encoder->setComputePipelineState(kernel);
 
   // Prepare the arguments for the kernel
   size_t reduction_size = plan.shape.back();
@@ -162,19 +299,7 @@ void strided_reduce_general_dispatch(
   }
   int ndim = shape.size();
 
-  // Set the arguments for the kernel
-  compute_encoder->setComputePipelineState(kernel);
-  set_array_buffer(compute_encoder, in, 0);
-  set_array_buffer(compute_encoder, out, 1);
-  compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
-  compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
-  compute_encoder->setBytes(&out_size, sizeof(size_t), 4);
-  compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5);
-  compute_encoder->setBytes(strides.data(), strides.size() * sizeof(size_t), 6);
-  compute_encoder->setBytes(&ndim, sizeof(int), 7);
-
   // Select block dimensions
 
   // Each thread reads 16 inputs to give it more work
   uint n_inputs_per_thread = REDUCE_N_READS;
   uint n_threads_per_output =
|
|||||||
// We spread outputs over the x dimension and inputs over the y dimension
|
// We spread outputs over the x dimension and inputs over the y dimension
|
||||||
// Threads with the same lid.x in a given threadgroup work on the same
|
// Threads with the same lid.x in a given threadgroup work on the same
|
||||||
// output and each thread in the y dimension accumulates for that output
|
// output and each thread in the y dimension accumulates for that output
|
||||||
|
|
||||||
|
// Threads with same lid.x, i.e. each column of threads work on same output
|
||||||
uint threadgroup_dim_x = std::min(out_size, 128ul);
|
uint threadgroup_dim_x = std::min(out_size, 128ul);
|
||||||
|
|
||||||
|
// Number of threads along y, is dependent on number of reductions needed.
|
||||||
uint threadgroup_dim_y =
|
uint threadgroup_dim_y =
|
||||||
kernel->maxTotalThreadsPerThreadgroup() / threadgroup_dim_x;
|
kernel->maxTotalThreadsPerThreadgroup() / threadgroup_dim_x;
|
||||||
threadgroup_dim_y = std::min(n_threads_per_output, threadgroup_dim_y);
|
threadgroup_dim_y = std::min(n_threads_per_output, threadgroup_dim_y);
|
||||||
|
|
||||||
|
// Derive number of thread groups along x, based on how many threads we need
|
||||||
|
// along x
|
||||||
uint n_threadgroups_x =
|
uint n_threadgroups_x =
|
||||||
(out_size + threadgroup_dim_x - 1) / threadgroup_dim_x;
|
(out_size + threadgroup_dim_x - 1) / threadgroup_dim_x;
|
||||||
|
|
||||||
|
// Derive number of thread groups along y based on how many threads we need
|
||||||
|
// along y
|
||||||
uint n_threadgroups_y =
|
uint n_threadgroups_y =
|
||||||
(n_threads_per_output + threadgroup_dim_y - 1) / threadgroup_dim_y;
|
(n_threads_per_output + threadgroup_dim_y - 1) / threadgroup_dim_y;
|
||||||
|
|
||||||
@@ -199,18 +332,122 @@ void strided_reduce_general_dispatch(
       MTL::Size(n_threadgroups_x, n_threadgroups_y, non_col_reductions);
   MTL::Size group_dims = MTL::Size(threadgroup_dim_x, threadgroup_dim_y, 1);
 
-  // We set shared memory to be exploited here for reductions within a
-  // threadgroup - each thread must be able to update its accumulated output
-  // Note: Each threadgroup should have 32kB of data in threadgroup memory
-  // and threadgroup_dim_x * threadgroup_dim_y <= 1024 by design
-  // This should be fine for floats, but we might need to revisit
-  // if we ever come to doubles. In that case, we should also cut
-  // down the number of threads we launch in a threadgroup
-  compute_encoder->setThreadgroupMemoryLength(
-      safe_divup(threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 16),
-      0);
-
-  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+  if (is_out_64b_int == false) {
+    // Set the arguments for the kernel
+    set_array_buffer(compute_encoder, in, 0);
+    set_array_buffer(compute_encoder, out, 1);
+    compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
+    compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
+    compute_encoder->setBytes(&out_size, sizeof(size_t), 4);
+    compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5);
+    compute_encoder->setBytes(
+        strides.data(), strides.size() * sizeof(size_t), 6);
+    compute_encoder->setBytes(&ndim, sizeof(int), 7);
+
+    // We set shared memory to be exploited here for reductions within a
+    // threadgroup - each thread must be able to update its accumulated output
+    // Note: Each threadgroup should have 32kB of data in threadgroup memory
+    // and threadgroup_dim_x * threadgroup_dim_y <= 1024 by design
+    // This should be fine for floats, but we might need to revisit
+    // if we ever come to doubles. In that case, we should also cut
+    // down the number of threads we launch in a threadgroup
+    compute_encoder->setThreadgroupMemoryLength(
+        safe_divup(threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 16),
+        0);
+    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+
+  } else {
+    // Allocate intermediate array to store reduction results from all thread
+    // groups
+    array intermediate = array(
+        {static_cast<int>(out.size()),
+         static_cast<int>(n_threadgroups_y * non_col_reductions)},
+        out_dtype,
+        nullptr,
+        {});
+    intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
+    std::vector<array> intermediates = {intermediate};
+
+    // Set the arguments for the kernel
+    set_array_buffer(compute_encoder, in, 0);
+    set_array_buffer(compute_encoder, intermediate, 1);
+    compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
+    compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
+    compute_encoder->setBytes(&out_size, sizeof(size_t), 4);
+    compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5);
+    compute_encoder->setBytes(
+        strides.data(), strides.size() * sizeof(size_t), 6);
+    compute_encoder->setBytes(&ndim, sizeof(int), 7);
+
+    // We set shared memory to be exploited here for reductions within a
+    // threadgroup - each thread must be able to update its accumulated output
+    // Note: Each threadgroup should have 32kB of data in threadgroup memory
+    // and threadgroup_dim_x * threadgroup_dim_y <= 1024 by design
+    // This should be fine for floats, but we might need to revisit
+    // if we ever come to doubles. In that case, we should also cut
+    // down the number of threads we launch in a threadgroup
+    compute_encoder->setThreadgroupMemoryLength(
+        safe_divup(threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 16),
+        0);
+    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+
+    // Perform second pass of reductions
+    // Reduce results of threadgroups along y, z from first pass, that
+    // collectively work on each output element.
+    reduction_size = n_threadgroups_y * non_col_reductions;
+    out_size = 1;
+
+    // Shape of axes that aren't participating in reduction remains unchanged.
+    std::vector<int> new_shape = rem_shape;
+
+    // Update their strides since they'll be different after a partial reduction
+    // post first compute dispatch.
+    std::vector<size_t> new_strides = rem_strides;
+    new_strides.back() = reduction_size;
+    for (int i = new_shape.size() - 2; i >= 0; i--) {
+      new_strides[i] = new_shape[i + 1] * new_strides[i + 1];
+    }
+    ndim = new_shape.size();
+
+    auto row_reduce_kernel = d.get_kernel(
+        "row_reduce_general_no_atomics_" + op_name +
+        type_to_name(intermediate));
+    compute_encoder->setComputePipelineState(row_reduce_kernel);
+    set_array_buffer(compute_encoder, intermediate, 0);
+    set_array_buffer(compute_encoder, out, 1);
+    compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
+    compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
+    compute_encoder->setBytes(
+        new_shape.data(), new_shape.size() * sizeof(int), 4);
+    compute_encoder->setBytes(
+        new_strides.data(), new_strides.size() * sizeof(size_t), 5);
+    compute_encoder->setBytes(&ndim, sizeof(int), 6);
+
+    // Each thread group is responsible for 1 output
+    size_t n_reads = REDUCE_N_READS;
+    size_t thread_group_size =
+        row_reduce_kernel->maxTotalThreadsPerThreadgroup();
+    thread_group_size =
+        std::min((reduction_size + n_reads - 1) / n_reads, thread_group_size);
+
+    // Align thread group size with simd_size
+    uint simd_size = row_reduce_kernel->threadExecutionWidth();
+    thread_group_size =
+        (thread_group_size + simd_size - 1) / simd_size * simd_size;
+    assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup());
+
+    // Launch enough thread groups for each output
+    uint n_threads = thread_group_size;
+    grid_dims = MTL::Size(n_threads, out.size(), 1);
+    group_dims = MTL::Size(thread_group_size, 1, 1);
+
+    compute_encoder->dispatchThreads(grid_dims, group_dims);
+
+    d.get_command_buffer(s.index)->addCompletedHandler(
+        [intermediates](MTL::CommandBuffer*) mutable {
+          intermediates.clear();
+        });
+  }
 }
 
 } // namespace
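Aside (not part of the diff): in all three dispatch paths the temporary intermediate arrays are kept alive by capturing them by value in the command buffer's completion handler, and released only once the GPU work finishes. A CPU sketch of that lifetime pattern with a stand-in command buffer (no Metal involved, all names here are illustrative):

#include <cstdio>
#include <functional>
#include <memory>
#include <vector>

// CPU sketch of the addCompletedHandler pattern: temporary buffers are
// captured by value in a completion callback, so they stay alive until the
// "GPU" signals completion, and clear() then drops the last references.
struct FakeCommandBuffer {
  std::vector<std::function<void()>> handlers;
  void addCompletedHandler(std::function<void()> h) { handlers.push_back(std::move(h)); }
  void complete() { for (auto& h : handlers) h(); handlers.clear(); }
};

int main() {
  FakeCommandBuffer cb;
  {
    auto intermediate = std::make_shared<std::vector<long>>(1024, 0);
    std::vector<std::shared_ptr<std::vector<long>>> intermediates = {intermediate};
    // Capture by value: the copy held by the handler keeps the buffer alive
    // even after this scope ends.
    cb.addCompletedHandler([intermediates]() mutable { intermediates.clear(); });
  }
  std::printf("buffer still alive until completion\n");
  cb.complete();  // now the intermediate buffer is released
  return 0;
}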
@@ -223,14 +460,6 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   array in = inputs[0];
 
-  // TODO: Allow specific row and column reductions with types disabled
-  // due to atomics ?
-  if (size_of(in.dtype()) == 8) {
-    std::ostringstream msg;
-    msg << "[Reduce::eval_gpu] Does not support " << in.dtype();
-    throw std::runtime_error(msg.str());
-  }
-
   // Make sure no identity reductions trickle down here
   assert(!axes_.empty());
 
@@ -297,7 +526,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Reducing over everything and the data is all there no broadcasting or
   // slicing etc.
   if (plan.type == ContiguousAllReduce) {
-    all_reduce_dispatch(in, out, op_name, compute_encoder, d);
+    all_reduce_dispatch(in, out, op_name, compute_encoder, d, s);
   }
 
   // At least the last dimension is row contiguous and we are reducing over
@@ -305,7 +534,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
   else if (
       plan.type == ContiguousReduce || plan.type == GeneralContiguousReduce) {
     row_reduce_general_dispatch(
-        in, out, op_name, plan, axes_, compute_encoder, d);
+        in, out, op_name, plan, axes_, compute_encoder, d, s);
   }
 
   // At least the last two dimensions are contiguous and we are doing a
@@ -314,7 +543,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
       plan.type == ContiguousStridedReduce ||
       plan.type == GeneralStridedReduce) {
     strided_reduce_general_dispatch(
-        in, out, op_name, plan, axes_, compute_encoder, d);
+        in, out, op_name, plan, axes_, compute_encoder, d, s);
   }
 
   if (!copies.empty()) {
@@ -55,6 +55,8 @@ class TestReduce(mlx_tests.MLXTestCase):
             "uint8",
             "uint16",
             "uint32",
+            "int64",
+            "uint64",
         ]
         float_dtypes = ["float32"]
 