Refactor the reduction kernels (#277)

parent 22fee5a383
commit 9e6b8c9f48
@@ -125,6 +125,14 @@ if __name__ == "__main__":
     compare_filtered("sum_axis --size 16x128x1024 --axis 1")
     compare_filtered("sum_axis --size 16x128x1024 --axis 0 --cpu")
     compare_filtered("sum_axis --size 16x128x1024 --axis 0")
+    compare_filtered("sum_axis --size 16x128x1024 --axis 0,1 --cpu")
+    compare_filtered("sum_axis --size 16x128x1024 --axis 0,1")
+    compare_filtered("sum_axis --size 16x128x1024 --axis 0,2 --cpu")
+    compare_filtered("sum_axis --size 16x128x1024 --axis 0,2")
+    compare_filtered("sum_axis --size 16x128x1024 --axis 0,1 --transpose 0,2,1 --cpu")
+    compare_filtered("sum_axis --size 16x128x1024 --axis 0,1 --transpose 0,2,1")
+    compare_filtered("sum_axis --size 16x128x1024 --axis 0,2 --transpose 0,2,1 --cpu")
+    compare_filtered("sum_axis --size 16x128x1024 --axis 0,2 --transpose 0,2,1")
     compare_filtered("argmax --size 10x1024x128 --axis 1 --cpu")
     compare_filtered("argmax --size 10x1024x128 --axis 1")
     compare_filtered("argmax --size 10x1024x128 --axis 2 --cpu")
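The added cases above exercise multi-axis and transposed-input reductions. As a rough illustration of what one such case compares (not the harness itself: compare_filtered and its flags come from the benchmark script, while the code below is a hand-written sketch against the mlx Python API):

import mlx.core as mx

# Roughly what "sum_axis --size 16x128x1024 --axis 0,1 --transpose 0,2,1"
# measures: reduce a transposed (hence non-contiguous) view over two axes on
# the GPU and on the CPU, and check that the two paths agree.
x = mx.random.uniform(shape=(16, 128, 1024))
xt = mx.transpose(x, (0, 2, 1))

y_gpu = mx.sum(xt, axis=[0, 1], stream=mx.gpu)
y_cpu = mx.sum(xt, axis=[0, 1], stream=mx.cpu)
mx.eval(y_gpu, y_cpu)

print(mx.max(mx.abs(y_gpu - y_cpu)))  # should be ~0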
@@ -126,7 +126,7 @@ struct ReductionPlan {

ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
  // The data is all there and we are reducing over everything
  if (x.size() == x.data_size() && axes.size() == x.ndim() &&
-      (x.flags().row_contiguous || x.flags().col_contiguous)) {
+      x.flags().contiguous) {
    return ContiguousAllReduce;
  }
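For reference, the tightened condition above can be restated as pseudocode. This is an illustrative sketch only; the attribute names mirror the C++ fields and flags, and the object they hang off is hypothetical:

# Hypothetical sketch: when an array's data is one dense block (no broadcast
# or slice gaps), every axis is being reduced, and the layout is contiguous
# (row- or column-major), the whole buffer can be reduced as a flat vector.
def picks_contiguous_all_reduce(x, axes):
    return (
        x.size == x.data_size      # all elements are physically present
        and len(axes) == x.ndim    # reducing over every axis
        and x.flags.contiguous     # one dense, contiguous block
    )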
@@ -112,80 +112,22 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
     uint simd_group_id [[simdgroup_index_in_threadgroup]]);
 
 
-///////////////////////////////////////////////////////////////////////////////
-// General reduce
-///////////////////////////////////////////////////////////////////////////////
-
-template <typename T, typename U, typename Op>
-[[kernel]] void general_reduce(
-    const device T *in [[buffer(0)]],
-    device mlx_atomic<U> *out [[buffer(1)]],
-    const device int *in_shape [[buffer(2)]],
-    const device size_t *in_strides [[buffer(3)]],
-    const device size_t *out_strides [[buffer(4)]],
-    const device size_t& ndim [[buffer(5)]],
-    uint gid [[thread_position_in_grid]]) {
-  Op op;
-  auto in_idx = elem_to_loc(gid, in_shape, in_strides, ndim);
-  auto out_idx = elem_to_loc(gid, in_shape, out_strides, ndim);
-  op.atomic_update(out, static_cast<U>(in[in_idx]), out_idx);
-}
-
-template <typename T, typename U, typename Op, int NDIM>
-[[kernel]] void general_reduce(
-    const device T *in [[buffer(0)]],
-    device mlx_atomic<U> *out [[buffer(1)]],
-    const device int *in_shape [[buffer(2)]],
-    const device size_t *in_strides [[buffer(3)]],
-    const device size_t *out_strides [[buffer(4)]],
-    uint gid [[thread_position_in_grid]]) {
-  Op op;
-  auto in_idx = elem_to_loc_nd<NDIM>(gid, in_shape, in_strides);
-  auto out_idx = elem_to_loc_nd<NDIM>(gid, in_shape, out_strides);
-  op.atomic_update(out, static_cast<U>(in[in_idx]), out_idx);
-}
-
-#define instantiate_general_reduce_helper(name, itype, otype, op) \
-  template [[host_name("general_reduce_" #name)]] \
-  [[kernel]] void general_reduce<itype, otype, op>( \
-      const device itype *in [[buffer(0)]], \
-      device mlx_atomic<otype> *out [[buffer(1)]], \
-      const device int *in_shape [[buffer(2)]], \
-      const device size_t *in_strides [[buffer(3)]], \
-      const device size_t *out_strides [[buffer(4)]], \
-      const device size_t& ndim [[buffer(5)]], \
-      uint gid [[thread_position_in_grid]]);
-
-#define instantiate_general_reduce_helper_nd(name, itype, otype, op, n) \
-  template [[host_name("general_reduce_" #name "_dim_" #n)]] \
-  [[kernel]] void general_reduce<itype, otype, op, n>( \
-      const device itype *in [[buffer(0)]], \
-      device mlx_atomic<otype> *out [[buffer(1)]], \
-      const device int *in_shape [[buffer(2)]], \
-      const device size_t *in_strides [[buffer(3)]], \
-      const device size_t *out_strides [[buffer(4)]], \
-      uint gid [[thread_position_in_grid]]);
-
-#define instantiate_general_reduce(name, itype, otype, op) \
-  instantiate_general_reduce_helper(name, itype, otype, op) \
-  instantiate_general_reduce_helper_nd(name, itype, otype, op, 1) \
-  instantiate_general_reduce_helper_nd(name, itype, otype, op, 2) \
-  instantiate_general_reduce_helper_nd(name, itype, otype, op, 3) \
-  instantiate_general_reduce_helper_nd(name, itype, otype, op, 4)
-
-
 ///////////////////////////////////////////////////////////////////////////////
 // Row atomics
 ///////////////////////////////////////////////////////////////////////////////
 
 template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
-[[kernel]] void row_reduce(
+[[kernel]] void row_reduce_general(
     const device T *in [[buffer(0)]],
-    device U *out [[buffer(1)]],
-    const device size_t& reduction_size [[buffer(2)]],
-    uint lid [[thread_position_in_threadgroup]],
-    uint lsize [[threads_per_threadgroup]],
-    uint tid [[threadgroup_position_in_grid]],
+    device mlx_atomic<U> *out [[buffer(1)]],
+    const constant size_t& reduction_size [[buffer(2)]],
+    const constant size_t& out_size [[buffer(3)]],
+    const constant int* shape [[buffer(4)]],
+    const constant size_t* strides [[buffer(5)]],
+    const constant int& ndim [[buffer(6)]],
+    uint3 lid [[thread_position_in_threadgroup]],
+    uint3 lsize [[threads_per_threadgroup]],
+    uint3 tid [[threadgroup_position_in_grid]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_per_group [[simdgroups_per_threadgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
@@ -193,7 +135,10 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
   Op op;
 
   // Each threadgroup handles 1 reduction
-  in += tid * reduction_size + lid * N_READS;
+  // TODO: Specializing elem_to_loc would be slightly faster
+  int idx = tid.y * out_size + tid.x;
+  int extra_offset = elem_to_loc(idx, shape, strides, ndim);
+  in += extra_offset + lid.x * N_READS;
 
   // The reduction is accumulated here
   U total_val = Op::init;
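The kernel now locates the start of its row by flattening the threadgroup coordinates into idx = tid.y * out_size + tid.x and resolving that index through elem_to_loc. A small Python sketch of what such a shape/strides walk computes (an assumption about the helper's intent, not a transcription of the Metal implementation):

# Convert a linear index over the logical shape into an element offset by
# peeling off one coordinate per dimension (innermost first) and weighting
# it with that dimension's stride.
def elem_to_loc(idx, shape, strides):
    loc = 0
    for dim, stride in zip(reversed(shape), reversed(strides)):
        loc += (idx % dim) * stride
        idx //= dim
    return loc

# Example: shape (2, 3) stored transposed, so strides are (1, 2);
# logical element (1, 2) -> linear index 5 -> offset 1*1 + 2*2 = 5.
assert elem_to_loc(5, [2, 3], [1, 2]) == 5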
@@ -201,7 +146,7 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
 
   // Loop over the reduction size within thread group
   int r = 0;
-  for (; r < (int)ceildiv(reduction_size, N_READS*lsize) - 1; r++) {
+  for (; r < (int)ceildiv(reduction_size, N_READS*lsize.x) - 1; r++) {
     T vals[N_READS];
     for(int i = 0; i < N_READS; i++) {
       vals[i] = in[i];
@@ -210,11 +155,11 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
       total_val = op(static_cast<U>(vals[i]), total_val);
     }
 
-    in += lsize * N_READS;
+    in += lsize.x * N_READS;
   }
 
-  // Sepate case for the last set as we close the reduction size
-  size_t reduction_index = (lid + (size_t)lsize * r) * N_READS;
+  // Separate case for the last set as we close the reduction size
+  size_t reduction_index = (lid.x + (size_t)lsize.x * r) * N_READS;
   if(reduction_index < reduction_size) {
     int max_reads = reduction_size - reduction_index;
 
@@ -240,24 +185,28 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
   // Reduction within thread group
   // Only needed if multiple simd groups
   if(reduction_size > simd_size) {
-    total_val = lid < simd_per_group ? local_vals[lid] : op.init;
+    total_val = lid.x < simd_per_group ? local_vals[lid.x] : op.init;
     total_val = op.simd_reduce(total_val);
   }
   // Update output
-  if (lid == 0) {
-    out[tid] = total_val;
+  if (lid.x == 0) {
+    op.atomic_update(out, total_val, tid.x);
   }
 }
 
-#define instantiate_row_reduce(name, itype, otype, op) \
-  template [[host_name("row_reduce_" #name)]] \
-  [[kernel]] void row_reduce<itype, otype, op>( \
+#define instantiate_row_reduce_general(name, itype, otype, op) \
+  template [[host_name("row_reduce_general_" #name)]] \
+  [[kernel]] void row_reduce_general<itype, otype, op>( \
       const device itype *in [[buffer(0)]], \
-      device otype *out [[buffer(1)]], \
-      const device size_t& reduction_size [[buffer(2)]], \
-      uint lid [[thread_position_in_threadgroup]], \
-      uint lsize [[threads_per_threadgroup]], \
-      uint tid [[threadgroup_position_in_grid]], \
+      device mlx_atomic<otype> *out [[buffer(1)]], \
+      const constant size_t& reduction_size [[buffer(2)]], \
+      const constant size_t& out_size [[buffer(3)]], \
+      const constant int* shape [[buffer(4)]], \
+      const constant size_t* strides [[buffer(5)]], \
+      const constant int& ndim [[buffer(6)]], \
+      uint3 lid [[thread_position_in_threadgroup]], \
+      uint3 lsize [[threads_per_threadgroup]], \
+      uint3 tid [[threadgroup_position_in_grid]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_per_group [[simdgroups_per_threadgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);
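The surrounding kernel accumulates in two stages: each thread folds N_READS-wide chunks of its row into a private partial, and the partials are then combined within and across simdgroups before a single thread commits the result (now via an atomic update keyed by tid.x, so threadgroups along the grid's y dimension can fold into the same output). A plain-Python sketch of that structure, with assumed sizes, purely for illustration:

# Illustrative two-stage reduction over one "row" of length reduction_size,
# mimicking the kernel's shape: per-thread strided accumulation, then a
# combine of the per-thread partials.
def row_reduce(row, op, init, n_threads=32, n_reads=4):
    # Stage 1: each thread accumulates a strided slice of the row.
    partials = [init] * n_threads
    for t in range(n_threads):
        for start in range(0, len(row), n_threads * n_reads):
            for i in range(n_reads):
                j = start + t * n_reads + i
                if j < len(row):
                    partials[t] = op(partials[t], row[j])
    # Stage 2: combine the partials (the kernel does this with simd_reduce
    # within a simdgroup, then once more across simdgroups).
    total = init
    for p in partials:
        total = op(total, p)
    return total

print(row_reduce(list(range(1000)), lambda a, b: a + b, 0))  # 499500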
@@ -311,62 +260,26 @@ inline void _contiguous_strided_reduce(
 }
 
 template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
-[[kernel]] void col_reduce(
+[[kernel]] void col_reduce_general(
     const device T *in [[buffer(0)]],
     device mlx_atomic<U> *out [[buffer(1)]],
     const constant size_t& reduction_size [[buffer(2)]],
     const constant size_t& reduction_stride [[buffer(3)]],
     const constant size_t& out_size [[buffer(4)]],
+    const constant int* shape [[buffer(5)]],
+    const constant size_t* strides [[buffer(6)]],
+    const constant int& ndim [[buffer(7)]],
     threadgroup U *local_data [[threadgroup(0)]],
-    uint2 tid [[threadgroup_position_in_grid]],
-    uint2 lid [[thread_position_in_threadgroup]],
-    uint2 lsize [[threads_per_threadgroup]]) {
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint3 lid [[thread_position_in_threadgroup]],
+    uint3 lsize [[threads_per_threadgroup]]) {
   auto out_idx = tid.x * lsize.x + lid.x;
-  if(out_idx < out_size) {
-    _contiguous_strided_reduce<T, U, Op, N_READS>(
-        in,
-        out,
-        local_data,
-        out_idx,
-        out_idx,
-        reduction_size,
-        reduction_stride,
-        tid,
-        lid,
-        lsize);
-  }
-}
-
-#define instantiate_col_reduce(name, itype, otype, op) \
-  template [[host_name("col_reduce_" #name)]] \
-  [[kernel]] void col_reduce<itype, otype, op>( \
-      const device itype *in [[buffer(0)]], \
-      device mlx_atomic<otype> *out [[buffer(1)]], \
-      const constant size_t& reduction_size [[buffer(2)]], \
-      const constant size_t& reduction_stride [[buffer(3)]], \
-      const constant size_t& out_size [[buffer(4)]], \
-      threadgroup otype *local_data [[threadgroup(0)]], \
-      uint2 tid [[threadgroup_position_in_grid]], \
-      uint2 lid [[thread_position_in_threadgroup]], \
-      uint2 lsize [[threads_per_threadgroup]]);
-
-template <typename T, typename U, typename Op, int NDIM, int N_READS = 16>
-[[kernel]] void contiguous_strided_reduce(
-    const device T *in [[buffer(0)]],
-    device mlx_atomic<U> *out [[buffer(1)]],
-    const constant size_t& reduction_size [[buffer(2)]],
-    const constant size_t& reduction_stride [[buffer(3)]],
-    const constant size_t& out_size [[buffer(4)]],
-    const device int* in_shape [[buffer(5)]],
-    const device size_t* in_strides [[buffer(6)]],
-    threadgroup U *local_data [[threadgroup(0)]],
-    uint2 tid [[threadgroup_position_in_grid]],
-    uint2 lid [[thread_position_in_threadgroup]],
-    uint2 lsize [[threads_per_threadgroup]]) {
-
-  auto out_idx = tid.x * lsize.x + lid.x;
-  auto in_idx = elem_to_loc_nd<NDIM>(out_idx, in_shape, in_strides);
-
+  auto in_idx = elem_to_loc(
+      out_idx + tid.z * out_size,
+      shape,
+      strides,
+      ndim
+  );
+
   if(out_idx < out_size) {
     _contiguous_strided_reduce<T, U, Op, N_READS>(
@@ -377,82 +290,27 @@ template <typename T, typename U, typename Op, int NDIM, int N_READS = 16>
         out_idx,
         reduction_size,
         reduction_stride,
-        tid,
-        lid,
-        lsize);
+        tid.xy,
+        lid.xy,
+        lsize.xy);
   }
 }
 
-template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
-[[kernel]] void contiguous_strided_reduce(
-    const device T *in [[buffer(0)]],
-    device mlx_atomic<U> *out [[buffer(1)]],
-    const constant size_t& reduction_size [[buffer(2)]],
-    const constant size_t& reduction_stride [[buffer(3)]],
-    const constant size_t& out_size [[buffer(4)]],
-    const device int* in_shape [[buffer(5)]],
-    const device size_t* in_strides [[buffer(6)]],
-    const device size_t& in_dim [[buffer(7)]],
-    threadgroup U *local_data [[threadgroup(0)]],
-    uint2 tid [[threadgroup_position_in_grid]],
-    uint2 lid [[thread_position_in_threadgroup]],
-    uint2 lsize [[threads_per_threadgroup]]) {
-
-  auto out_idx = tid.x * lsize.x + lid.x;
-  auto in_idx = elem_to_loc(out_idx, in_shape, in_strides, in_dim);
-
-  if(out_idx < out_size) {
-    _contiguous_strided_reduce<T, U, Op, N_READS>(
-        in,
-        out,
-        local_data,
-        in_idx,
-        out_idx,
-        reduction_size,
-        reduction_stride,
-        tid,
-        lid,
-        lsize);
-  }
-}
-
-#define instantiate_contiguous_strided_helper(name, itype, otype, op) \
-  template [[host_name("contiguous_strided_reduce_" #name)]] \
-  [[kernel]] void contiguous_strided_reduce<itype, otype, op>( \
+#define instantiate_col_reduce_general(name, itype, otype, op) \
+  template [[host_name("col_reduce_general_" #name)]] \
+  [[kernel]] void col_reduce_general<itype, otype, op>( \
       const device itype *in [[buffer(0)]], \
       device mlx_atomic<otype> *out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
      const constant size_t& reduction_stride [[buffer(3)]], \
      const constant size_t& out_size [[buffer(4)]], \
-      const device int* in_shape [[buffer(5)]], \
-      const device size_t* in_strides [[buffer(6)]], \
-      const device size_t& in_dim [[buffer(7)]], \
+      const constant int* shape [[buffer(5)]], \
+      const constant size_t* strides [[buffer(6)]], \
+      const constant int& ndim [[buffer(7)]], \
      threadgroup otype *local_data [[threadgroup(0)]], \
-      uint2 tid [[threadgroup_position_in_grid]], \
-      uint2 lid [[thread_position_in_threadgroup]], \
-      uint2 lsize [[threads_per_threadgroup]]);
-
-#define instantiate_contiguous_strided_helper_nd(name, itype, otype, op, n) \
-  template [[host_name("contiguous_strided_reduce_" #name "_dim_" #n)]] \
-  [[kernel]] void contiguous_strided_reduce<itype, otype, op, n>( \
-      const device itype *in [[buffer(0)]], \
-      device mlx_atomic<otype> *out [[buffer(1)]], \
-      const constant size_t& reduction_size [[buffer(2)]], \
-      const constant size_t& reduction_stride [[buffer(3)]], \
-      const constant size_t& out_size [[buffer(4)]], \
-      const device int* in_shape [[buffer(5)]], \
-      const device size_t* in_strides [[buffer(6)]], \
-      threadgroup otype *local_data [[threadgroup(0)]], \
-      uint2 tid [[threadgroup_position_in_grid]], \
-      uint2 lid [[thread_position_in_threadgroup]], \
-      uint2 lsize [[threads_per_threadgroup]]);
-
-#define instantiate_contiguous_strided(name, itype, otype, op) \
-  instantiate_contiguous_strided_helper(name, itype, otype, op) \
-  instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 1) \
-  instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 2) \
-  instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 3) \
-  instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 4)
-
+      uint3 tid [[threadgroup_position_in_grid]], \
+      uint3 lid [[thread_position_in_threadgroup]], \
+      uint3 lsize [[threads_per_threadgroup]]);
 
 ///////////////////////////////////////////////////////////////////////////////
@@ -461,10 +319,8 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
 
 #define instantiate_reduce(name, itype, otype, op) \
   instantiate_all_reduce(name, itype, otype, op) \
-  instantiate_row_reduce(name, itype, otype, op) \
-  instantiate_col_reduce(name, itype, otype, op) \
-  instantiate_contiguous_strided(name, itype, otype, op) \
-  instantiate_general_reduce(name, itype, otype, op)
+  instantiate_row_reduce_general(name, itype, otype, op) \
+  instantiate_col_reduce_general(name, itype, otype, op)
 
 #define instantiate_same_reduce(name, tname, type, op) \
   instantiate_init_reduce(name ##tname, type, op<type>) \
|
@ -2,9 +2,11 @@
|
|||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
|
#include <iostream>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
#include "mlx/backend/common/reduce.h"
|
#include "mlx/backend/common/reduce.h"
|
||||||
|
#include "mlx/backend/metal/copy.h"
|
||||||
#include "mlx/backend/metal/device.h"
|
#include "mlx/backend/metal/device.h"
|
||||||
#include "mlx/backend/metal/kernels/defines.h"
|
#include "mlx/backend/metal/kernels/defines.h"
|
||||||
#include "mlx/backend/metal/utils.h"
|
#include "mlx/backend/metal/utils.h"
|
||||||
@@ -61,22 +63,47 @@ void all_reduce_dispatch(
   compute_encoder->dispatchThreads(grid_dims, group_dims);
 }
 
-void row_reduce_dispatch(
+void row_reduce_general_dispatch(
     const array& in,
     array& out,
     const std::string& op_name,
-    const std::vector<int>& axes_,
+    const ReductionPlan& plan,
+    const std::vector<int>& axes,
     MTL::ComputeCommandEncoder* compute_encoder,
     metal::Device& d) {
-  auto kernel = d.get_kernel("row_reduce_" + op_name + type_to_name(in));
+  auto kernel =
+      d.get_kernel("row_reduce_general_" + op_name + type_to_name(in));
 
+  // Prepare the arguments for the kernel
   int n_reads = REDUCE_N_READS;
-  size_t reduction_size = in.size() / out.size();
+  size_t reduction_size = plan.shape.back();
+  size_t out_size = out.size();
+  auto shape = plan.shape;
+  auto strides = plan.strides;
+  shape.pop_back();
+  strides.pop_back();
+  size_t non_row_reductions = 1;
+  for (auto s : shape) {
+    non_row_reductions *= static_cast<size_t>(s);
+  }
+  auto [rem_shape, rem_strides] = shapes_without_reduction_axes(in, axes);
+  for (auto s : rem_shape) {
+    shape.push_back(s);
+  }
+  for (auto s : rem_strides) {
+    strides.push_back(s);
+  }
+  int ndim = shape.size();
 
+  // Set the arguments for the kernel
   compute_encoder->setComputePipelineState(kernel);
   set_array_buffer(compute_encoder, in, 0);
   set_array_buffer(compute_encoder, out, 1);
   compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
+  compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
+  compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4);
+  compute_encoder->setBytes(strides.data(), strides.size() * sizeof(size_t), 5);
+  compute_encoder->setBytes(&ndim, sizeof(int), 6);
 
   // Each thread group is responsible for 1 output
   NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
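The dispatch above splits the plan into the innermost row (its length becomes reduction_size), the remaining reduction dimensions (their product becomes non_row_reductions, mapped to the grid's y dimension), and the non-reduced dimensions, which are appended so the kernel can locate each row with elem_to_loc. A hedged Python sketch of that argument preparation; shapes_without_reduction_axes is the helper named in the diff, but its behavior here is an assumption:

# Sketch of assembling the row-reduce kernel arguments from a plan whose
# shape/strides describe the reduction axes, innermost last.
def prepare_row_reduce_args(plan_shape, plan_strides, in_shape, in_strides, axes):
    reduction_size = plan_shape[-1]          # length of one contiguous row
    shape, strides = plan_shape[:-1], plan_strides[:-1]

    # Reduction dims not merged into the row are spread over grid.y.
    non_row_reductions = 1
    for s in shape:
        non_row_reductions *= s

    # Append the non-reduced dims so elem_to_loc can find each row's start.
    rem_shape = [s for i, s in enumerate(in_shape) if i not in axes]
    rem_strides = [st for i, st in enumerate(in_strides) if i not in axes]
    shape, strides = shape + rem_shape, strides + rem_strides
    return reduction_size, non_row_reductions, shape, strides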
@@ -91,92 +118,54 @@ void row_reduce_dispatch(
 
   // Launch enough thread groups for each output
   size_t n_threads = out.size() * thread_group_size;
-  MTL::Size grid_dims = MTL::Size(n_threads, 1, 1);
+  MTL::Size grid_dims = MTL::Size(n_threads, non_row_reductions, 1);
   MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
 
   compute_encoder->dispatchThreads(grid_dims, group_dims);
 }
 
-void col_reduce_dispatch(
+void strided_reduce_general_dispatch(
     const array& in,
     array& out,
     const std::string& op_name,
-    const std::vector<int>& axes_,
+    const ReductionPlan& plan,
+    const std::vector<int>& axes,
     MTL::ComputeCommandEncoder* compute_encoder,
     metal::Device& d) {
-  std::ostringstream kernel_name;
-
-  bool encode_in_shape = false;
-  bool encode_ndim = false;
-
-  // If the slowest moving axis can be merged into the reductions,
-  // we call the column reduce kernel
-  // In this case, a linear index in the output corresponds to the
-  // linear index in the input where the reduction starts
-  if (axes_[axes_.size() - 1] == (axes_.size() - 1)) {
-    kernel_name << "col_reduce_" << op_name << type_to_name(in);
-  }
-  // Otherwise, while all the reduction axes can be merged, the mapping between
-  // indices in the output and input require resolving using shapes and strides
-  else {
-    kernel_name << "contiguous_strided_reduce_" << op_name << type_to_name(in);
-    encode_in_shape = true;
-
-    // We check for a viable template with the required number of dimensions
-    // we only care about encoding non-reduced shapes and strides in the input
-    size_t non_reducing_dims = in.ndim() - axes_.size();
-    if (non_reducing_dims >= 1 &&
-        non_reducing_dims <= MAX_REDUCE_SPECIALIZED_DIMS) {
-      kernel_name << "_dim_" << non_reducing_dims;
-    } else {
-      encode_ndim = true;
-    }
-  }
-
-  auto kernel = d.get_kernel(kernel_name.str());
-  size_t in_size = in.size();
+  auto kernel =
+      d.get_kernel("col_reduce_general_" + op_name + type_to_name(in));
+
+  // Prepare the arguments for the kernel
+  size_t reduction_size = plan.shape.back();
+  size_t reduction_stride = plan.strides.back();
   size_t out_size = out.size();
+  auto shape = plan.shape;
+  auto strides = plan.strides;
+  shape.pop_back();
+  strides.pop_back();
+  size_t non_col_reductions = 1;
+  for (auto s : shape) {
+    non_col_reductions *= static_cast<size_t>(s);
+  }
+  auto [rem_shape, rem_strides] = shapes_without_reduction_axes(in, axes);
+  for (auto s : rem_shape) {
+    shape.push_back(s);
+  }
+  for (auto s : rem_strides) {
+    strides.push_back(s);
+  }
+  int ndim = shape.size();
 
+  // Set the arguments for the kernel
   compute_encoder->setComputePipelineState(kernel);
   set_array_buffer(compute_encoder, in, 0);
   set_array_buffer(compute_encoder, out, 1);
-
-  // Calculate the number of inputs to reduce and the stride b/w them
-  size_t reduction_size = 1;
-  size_t in_ndim = in.ndim();
-  size_t reduction_stride = in_size;
-
-  for (int i : axes_) {
-    reduction_size *= in.shape(i);
-    reduction_stride = std::min(reduction_stride, in.strides()[i]);
-  }
-
   compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
   compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
   compute_encoder->setBytes(&out_size, sizeof(size_t), 4);
-  if (encode_in_shape) {
-    // Obtain the non-reducing shape and strides of the input to encode
-    std::vector<int> inp_shape_mod;
-    std::vector<size_t> inp_strides_mod;
-
-    for (size_t i = 0, j = 0; i < in.ndim(); i++) {
-      if (j < axes_.size() && axes_[j] == i) {
-        j++;
-      } else {
-        inp_shape_mod.push_back(in.shape(i));
-        inp_strides_mod.push_back(in.strides()[i]);
-      }
-    }
-
-    size_t ndim = inp_shape_mod.size();
-
-    compute_encoder->setBytes(inp_shape_mod.data(), ndim * sizeof(int), 5);
-    compute_encoder->setBytes(inp_strides_mod.data(), ndim * sizeof(size_t), 6);
-
-    if (encode_ndim) {
-      compute_encoder->setBytes(&ndim, sizeof(size_t), 7);
-    }
-  }
+  compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5);
+  compute_encoder->setBytes(strides.data(), strides.size() * sizeof(size_t), 6);
+  compute_encoder->setBytes(&ndim, sizeof(int), 7);
 
   // Select block dimensions
 
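For the strided case the kernel is handed the length of each reduction, the stride between consecutive elements of one reduction, and the number of outputs; any leftover reduction dimensions become non_col_reductions on the grid's z axis. A pure-Python sketch of the strided reduction this describes (illustrative only, using the simplest case where the output index equals the starting offset):

# Each output combines reduction_size inputs that sit reduction_stride apart
# in memory, starting at the location the output index maps to.
def strided_reduce(data, out_size, reduction_size, reduction_stride, op, init):
    out = [init] * out_size
    for i in range(out_size):
        base = i  # contiguous case: the output index is already the offset
        for r in range(reduction_size):
            out[i] = op(out[i], data[base + r * reduction_stride])
    return out

# 4x3 row-major matrix reduced over axis 0: elements of one column are 3
# apart (the row length), and there are 3 outputs.
data = list(range(12))
print(strided_reduce(data, out_size=3, reduction_size=4, reduction_stride=3,
                     op=lambda a, b: a + b, init=0))  # [18, 22, 26]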
@@ -200,7 +189,8 @@ void col_reduce_dispatch(
       (n_threads_per_output + threadgroup_dim_y - 1) / threadgroup_dim_y;
 
   // Launch enough thread groups for each output
-  MTL::Size grid_dims = MTL::Size(n_threadgroups_x, n_threadgroups_y, 1);
+  MTL::Size grid_dims =
+      MTL::Size(n_threadgroups_x, n_threadgroups_y, non_col_reductions);
   MTL::Size group_dims = MTL::Size(threadgroup_dim_x, threadgroup_dim_y, 1);
 
   // We set shared memory to be exploited here for reductions within a
@@ -216,60 +206,6 @@ void col_reduce_dispatch(
   compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
 }
 
-void general_reduce_dispatch(
-    const array& in,
-    array& out,
-    const std::string& op_name,
-    const std::vector<int>& axes_,
-    MTL::ComputeCommandEncoder* compute_encoder,
-    metal::Device& d) {
-  bool encode_ndim = true;
-  std::ostringstream kernel_name;
-  kernel_name << "general_reduce_" << op_name << type_to_name(in);
-
-  // Check for specialzed kernels for input ndim
-  if (in.ndim() >= 1 && in.ndim() <= MAX_REDUCE_SPECIALIZED_DIMS) {
-    kernel_name << "_dim_" << in.ndim();
-    encode_ndim = false;
-  }
-  auto kernel = d.get_kernel(kernel_name.str());
-  size_t in_size = in.size();
-  size_t ndim = in.ndim();
-
-  // We set the reducing strides to 0 to induce collisions for the reduction
-  std::vector<size_t> out_strides(ndim);
-  size_t stride = 1;
-  for (int i = ndim - 1, j = axes_.size() - 1; i >= 0; --i) {
-    if (j >= 0 && axes_[j] == i) {
-      out_strides[i] = 0;
-      --j;
-    } else {
-      out_strides[i] = stride;
-      stride *= in.shape(i);
-    }
-  }
-
-  compute_encoder->setComputePipelineState(kernel);
-  set_array_buffer(compute_encoder, in, 0);
-  set_array_buffer(compute_encoder, out, 1);
-  compute_encoder->setBytes(in.shape().data(), ndim * sizeof(int), 2);
-  compute_encoder->setBytes(in.strides().data(), ndim * sizeof(size_t), 3);
-  compute_encoder->setBytes(out_strides.data(), ndim * sizeof(size_t), 4);
-  if (encode_ndim) {
-    compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
-  }
-
-  NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
-  if (thread_group_size > in_size) {
-    thread_group_size = in_size;
-  }
-  size_t nthreads = in_size;
-
-  MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-  MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
-  compute_encoder->dispatchThreads(grid_dims, group_dims);
-}
-
 } // namespace
 
 //////////////////////////////////////////////////////////////////////
@@ -278,7 +214,7 @@ void general_reduce_dispatch(
 
 void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
-  auto& in = inputs[0];
+  array in = inputs[0];
 
   // TODO: Allow specific row and column reductions with types disabled
   // due to atomics ?
@@ -335,37 +271,47 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
 
   // Reduce
   {
-    // Check for contiguous data
-    if (in.size() == in.data_size() &&
-        (in.flags().row_contiguous || in.flags().col_contiguous)) {
-      // Go to all reduce if reducing over all axes
-      if (axes_.size() == in.ndim()) {
-        all_reduce_dispatch(in, out, op_name, compute_encoder, d);
-        return;
-      }
-      // Use specialized kernels if the input is row contiguous and
-      // the reducing axes can be merged into one
-      else if (
-          in.flags().row_contiguous && in.strides().back() == 1 &&
-          (axes_.back() - axes_.front()) == axes_.size() - 1) {
-        // If the fastest moving axis is being reduced, go to row reduce
-        if (axes_[0] == (in.ndim() - axes_.size())) {
-          row_reduce_dispatch(in, out, op_name, axes_, compute_encoder, d);
-          return;
-        }
-        // Otherwise go to to generalized strided reduce
-        // Note: bool isn't support here yet due to the use of atomics
-        // once that is updated, this should be the else condition of this
-        // branch
-        else if (in.dtype() != bool_) {
-          col_reduce_dispatch(in, out, op_name, axes_, compute_encoder, d);
-          return;
-        }
-      }
-    }
-    // Fall back to the general case
-    general_reduce_dispatch(in, out, op_name, axes_, compute_encoder, d);
+    std::vector<array> copies;
+    ReductionPlan plan = get_reduction_plan(in, axes_);
+
+    // If it is a general reduce then copy the input to a contiguous array and
+    // recompute the plan.
+    if (plan.type == GeneralReduce) {
+      array in_copy(in.shape(), in.dtype(), nullptr, {});
+      copy_gpu(in, in_copy, CopyType::General, s);
+      copies.push_back(in_copy);
+      in = in_copy;
+      plan = get_reduction_plan(in, axes_);
+    }
+
+    // Reducing over everything and the data is all there no broadcasting or
+    // slicing etc.
+    if (plan.type == ContiguousAllReduce) {
+      all_reduce_dispatch(in, out, op_name, compute_encoder, d);
+    }
+
+    // At least the last dimension is row contiguous and we are reducing over
+    // the last dim.
+    else if (
+        plan.type == ContiguousReduce || plan.type == GeneralContiguousReduce) {
+      row_reduce_general_dispatch(
+          in, out, op_name, plan, axes_, compute_encoder, d);
+    }
+
+    // At least the last two dimensions are contiguous and we are doing a
+    // strided reduce over these.
+    else if (
+        plan.type == ContiguousStridedReduce ||
+        plan.type == GeneralStridedReduce) {
+      strided_reduce_general_dispatch(
+          in, out, op_name, plan, axes_, compute_encoder, d);
+    }
+
+    if (!copies.empty()) {
+      d.get_command_buffer(s.index)->addCompletedHandler(
+          [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
+    }
   }
 }
 
 } // namespace mlx::core