Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-24 09:21:16 +08:00)
Refactor the reduction kernels (#277)
parent 22fee5a383
commit 9e6b8c9f48
@@ -125,6 +125,14 @@ if __name__ == "__main__":
compare_filtered("sum_axis --size 16x128x1024 --axis 1")
compare_filtered("sum_axis --size 16x128x1024 --axis 0 --cpu")
compare_filtered("sum_axis --size 16x128x1024 --axis 0")
compare_filtered("sum_axis --size 16x128x1024 --axis 0,1 --cpu")
compare_filtered("sum_axis --size 16x128x1024 --axis 0,1")
compare_filtered("sum_axis --size 16x128x1024 --axis 0,2 --cpu")
compare_filtered("sum_axis --size 16x128x1024 --axis 0,2")
compare_filtered("sum_axis --size 16x128x1024 --axis 0,1 --transpose 0,2,1 --cpu")
compare_filtered("sum_axis --size 16x128x1024 --axis 0,1 --transpose 0,2,1")
compare_filtered("sum_axis --size 16x128x1024 --axis 0,2 --transpose 0,2,1 --cpu")
compare_filtered("sum_axis --size 16x128x1024 --axis 0,2 --transpose 0,2,1")
compare_filtered("argmax --size 10x1024x128 --axis 1 --cpu")
compare_filtered("argmax --size 10x1024x128 --axis 1")
compare_filtered("argmax --size 10x1024x128 --axis 2 --cpu")
@@ -126,7 +126,7 @@ struct ReductionPlan {
ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
// The data is all there and we are reducing over everything
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
(x.flags().row_contiguous || x.flags().col_contiguous)) {
x.flags().contiguous) {
return ContiguousAllReduce;
}
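Note: the plan types this diff keys off (ContiguousAllReduce, ContiguousReduce, GeneralContiguousReduce, ContiguousStridedReduce, GeneralStridedReduce, GeneralReduce) hang off the ReductionPlan struct referenced in the hunk header above. The following is a minimal sketch reconstructed from how the dispatch code below uses plan.type, plan.shape and plan.strides; it is not the actual header, and any member beyond those three is an assumption.

// Sketch only, reconstructed from call sites in this diff.
#include <cstddef>
#include <vector>

enum ReductionOpType {
  ContiguousAllReduce,
  ContiguousReduce,
  GeneralContiguousReduce,
  ContiguousStridedReduce,
  GeneralStridedReduce,
  GeneralReduce,
};

struct ReductionPlan {
  ReductionOpType type;
  // Merged reduction axes: shape[i]/strides[i] describe one reduction loop,
  // with the innermost (fastest-moving) one last; the dispatchers below read
  // plan.shape.back() and plan.strides.back() for that innermost loop.
  std::vector<int> shape;
  std::vector<std::size_t> strides;
};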
@@ -112,88 +112,33 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
uint simd_group_id [[simdgroup_index_in_threadgroup]]);

///////////////////////////////////////////////////////////////////////////////
// General reduce
///////////////////////////////////////////////////////////////////////////////

template <typename T, typename U, typename Op>
[[kernel]] void general_reduce(
const device T *in [[buffer(0)]],
device mlx_atomic<U> *out [[buffer(1)]],
const device int *in_shape [[buffer(2)]],
const device size_t *in_strides [[buffer(3)]],
const device size_t *out_strides [[buffer(4)]],
const device size_t& ndim [[buffer(5)]],
uint gid [[thread_position_in_grid]]) {
Op op;
auto in_idx = elem_to_loc(gid, in_shape, in_strides, ndim);
auto out_idx = elem_to_loc(gid, in_shape, out_strides, ndim);
op.atomic_update(out, static_cast<U>(in[in_idx]), out_idx);
}

template <typename T, typename U, typename Op, int NDIM>
[[kernel]] void general_reduce(
const device T *in [[buffer(0)]],
device mlx_atomic<U> *out [[buffer(1)]],
const device int *in_shape [[buffer(2)]],
const device size_t *in_strides [[buffer(3)]],
const device size_t *out_strides [[buffer(4)]],
uint gid [[thread_position_in_grid]]) {
Op op;
auto in_idx = elem_to_loc_nd<NDIM>(gid, in_shape, in_strides);
auto out_idx = elem_to_loc_nd<NDIM>(gid, in_shape, out_strides);
op.atomic_update(out, static_cast<U>(in[in_idx]), out_idx);
}

#define instantiate_general_reduce_helper(name, itype, otype, op) \
template [[host_name("general_reduce_" #name)]] \
[[kernel]] void general_reduce<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device mlx_atomic<otype> *out [[buffer(1)]], \
const device int *in_shape [[buffer(2)]], \
const device size_t *in_strides [[buffer(3)]], \
const device size_t *out_strides [[buffer(4)]], \
const device size_t& ndim [[buffer(5)]], \
uint gid [[thread_position_in_grid]]);

#define instantiate_general_reduce_helper_nd(name, itype, otype, op, n) \
template [[host_name("general_reduce_" #name "_dim_" #n)]] \
[[kernel]] void general_reduce<itype, otype, op, n>( \
const device itype *in [[buffer(0)]], \
device mlx_atomic<otype> *out [[buffer(1)]], \
const device int *in_shape [[buffer(2)]], \
const device size_t *in_strides [[buffer(3)]], \
const device size_t *out_strides [[buffer(4)]], \
uint gid [[thread_position_in_grid]]);

#define instantiate_general_reduce(name, itype, otype, op) \
instantiate_general_reduce_helper(name, itype, otype, op) \
instantiate_general_reduce_helper_nd(name, itype, otype, op, 1) \
instantiate_general_reduce_helper_nd(name, itype, otype, op, 2) \
instantiate_general_reduce_helper_nd(name, itype, otype, op, 3) \
instantiate_general_reduce_helper_nd(name, itype, otype, op, 4)

///////////////////////////////////////////////////////////////////////////////
// Row atomics
///////////////////////////////////////////////////////////////////////////////

template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void row_reduce(
[[kernel]] void row_reduce_general(
const device T *in [[buffer(0)]],
device U *out [[buffer(1)]],
const device size_t& reduction_size [[buffer(2)]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint tid [[threadgroup_position_in_grid]],
device mlx_atomic<U> *out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant int* shape [[buffer(4)]],
const constant size_t* strides [[buffer(5)]],
const constant int& ndim [[buffer(6)]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {

Op op;

// Each threadgroup handles 1 reduction
in += tid * reduction_size + lid * N_READS;
// Each threadgroup handles 1 reduction
// TODO: Specializing elem_to_loc would be slightly faster
int idx = tid.y * out_size + tid.x;
int extra_offset = elem_to_loc(idx, shape, strides, ndim);
in += extra_offset + lid.x * N_READS;

// The reduction is accumulated here
U total_val = Op::init;
@@ -201,7 +146,7 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>

// Loop over the reduction size within thread group
int r = 0;
for (; r < (int)ceildiv(reduction_size, N_READS*lsize) - 1; r++) {
for (; r < (int)ceildiv(reduction_size, N_READS*lsize.x) - 1; r++) {
T vals[N_READS];
for(int i = 0; i < N_READS; i++) {
vals[i] = in[i];
@@ -210,11 +155,11 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
total_val = op(static_cast<U>(vals[i]), total_val);
}

in += lsize * N_READS;
in += lsize.x * N_READS;
}

// Sepate case for the last set as we close the reduction size
size_t reduction_index = (lid + (size_t)lsize * r) * N_READS;
// Separate case for the last set as we close the reduction size
size_t reduction_index = (lid.x + (size_t)lsize.x * r) * N_READS;
if(reduction_index < reduction_size) {
int max_reads = reduction_size - reduction_index;
@@ -240,26 +185,30 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
// Reduction within thread group
// Only needed if multiple simd groups
if(reduction_size > simd_size) {
total_val = lid < simd_per_group ? local_vals[lid] : op.init;
total_val = lid.x < simd_per_group ? local_vals[lid.x] : op.init;
total_val = op.simd_reduce(total_val);
}
// Update output
if (lid == 0) {
out[tid] = total_val;
if (lid.x == 0) {
op.atomic_update(out, total_val, tid.x);
}
}

#define instantiate_row_reduce(name, itype, otype, op) \
template [[host_name("row_reduce_" #name)]] \
[[kernel]] void row_reduce<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device otype *out [[buffer(1)]], \
const device size_t& reduction_size [[buffer(2)]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
#define instantiate_row_reduce_general(name, itype, otype, op) \
template [[host_name("row_reduce_general_" #name)]] \
[[kernel]] void row_reduce_general<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device mlx_atomic<otype> *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant int* shape [[buffer(4)]], \
const constant size_t* strides [[buffer(5)]], \
const constant int& ndim [[buffer(6)]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
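The renamed row_reduce_general kernel above no longer assumes that the rows being reduced sit back to back (the old version simply did in += tid * reduction_size); instead it resolves the start of each row with elem_to_loc(idx, shape, strides, ndim). The helper below is a rough C++ sketch of what that mapping computes, written from its call sites in this diff; the name elem_to_loc_sketch and the exact signature are assumptions, not the actual MLX utility.

// Sketch: map a flat element index to an offset through an arbitrary
// shape/strides pair, innermost dimension last.
#include <cstddef>

inline std::size_t elem_to_loc_sketch(
    std::size_t elem,
    const int* shape,
    const std::size_t* strides,
    int ndim) {
  std::size_t loc = 0;
  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}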
@@ -311,148 +260,57 @@ inline void _contiguous_strided_reduce(
}

template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void col_reduce(
[[kernel]] void col_reduce_general(
const device T *in [[buffer(0)]],
device mlx_atomic<U> *out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
threadgroup U *local_data [[threadgroup(0)]],
uint2 tid [[threadgroup_position_in_grid]],
uint2 lid [[thread_position_in_threadgroup]],
uint2 lsize [[threads_per_threadgroup]]) {
auto out_idx = tid.x * lsize.x + lid.x;

uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]]) {
auto out_idx = tid.x * lsize.x + lid.x;
auto in_idx = elem_to_loc(
out_idx + tid.z * out_size,
shape,
strides,
ndim
);

if(out_idx < out_size) {
_contiguous_strided_reduce<T, U, Op, N_READS>(
in,
out,
local_data,
out_idx,
in_idx,
out_idx,
reduction_size,
reduction_stride,
tid,
lid,
lsize);
tid.xy,
lid.xy,
lsize.xy);
}
}

#define instantiate_col_reduce(name, itype, otype, op) \
template [[host_name("col_reduce_" #name)]] \
[[kernel]] void col_reduce<itype, otype, op>( \
#define instantiate_col_reduce_general(name, itype, otype, op) \
template [[host_name("col_reduce_general_" #name)]] \
[[kernel]] void col_reduce_general<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device mlx_atomic<otype> *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
threadgroup otype *local_data [[threadgroup(0)]], \
uint2 tid [[threadgroup_position_in_grid]], \
uint2 lid [[thread_position_in_threadgroup]], \
uint2 lsize [[threads_per_threadgroup]]);

template <typename T, typename U, typename Op, int NDIM, int N_READS = 16>
[[kernel]] void contiguous_strided_reduce(
const device T *in [[buffer(0)]],
device mlx_atomic<U> *out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const device int* in_shape [[buffer(5)]],
const device size_t* in_strides [[buffer(6)]],
threadgroup U *local_data [[threadgroup(0)]],
uint2 tid [[threadgroup_position_in_grid]],
uint2 lid [[thread_position_in_threadgroup]],
uint2 lsize [[threads_per_threadgroup]]) {

auto out_idx = tid.x * lsize.x + lid.x;
auto in_idx = elem_to_loc_nd<NDIM>(out_idx, in_shape, in_strides);

if(out_idx < out_size) {
_contiguous_strided_reduce<T, U, Op, N_READS>(
in,
out,
local_data,
in_idx,
out_idx,
reduction_size,
reduction_stride,
tid,
lid,
lsize);
}
}

template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void contiguous_strided_reduce(
const device T *in [[buffer(0)]],
device mlx_atomic<U> *out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const device int* in_shape [[buffer(5)]],
const device size_t* in_strides [[buffer(6)]],
const device size_t& in_dim [[buffer(7)]],
threadgroup U *local_data [[threadgroup(0)]],
uint2 tid [[threadgroup_position_in_grid]],
uint2 lid [[thread_position_in_threadgroup]],
uint2 lsize [[threads_per_threadgroup]]) {

auto out_idx = tid.x * lsize.x + lid.x;
auto in_idx = elem_to_loc(out_idx, in_shape, in_strides, in_dim);

if(out_idx < out_size) {
_contiguous_strided_reduce<T, U, Op, N_READS>(
in,
out,
local_data,
in_idx,
out_idx,
reduction_size,
reduction_stride,
tid,
lid,
lsize);
}
}

#define instantiate_contiguous_strided_helper(name, itype, otype, op) \
template [[host_name("contiguous_strided_reduce_" #name)]] \
[[kernel]] void contiguous_strided_reduce<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device mlx_atomic<otype> *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const device int* in_shape [[buffer(5)]], \
const device size_t* in_strides [[buffer(6)]], \
const device size_t& in_dim [[buffer(7)]], \
threadgroup otype *local_data [[threadgroup(0)]], \
uint2 tid [[threadgroup_position_in_grid]], \
uint2 lid [[thread_position_in_threadgroup]], \
uint2 lsize [[threads_per_threadgroup]]);

#define instantiate_contiguous_strided_helper_nd(name, itype, otype, op, n) \
template [[host_name("contiguous_strided_reduce_" #name "_dim_" #n)]] \
[[kernel]] void contiguous_strided_reduce<itype, otype, op, n>( \
const device itype *in [[buffer(0)]], \
device mlx_atomic<otype> *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const device int* in_shape [[buffer(5)]], \
const device size_t* in_strides [[buffer(6)]], \
threadgroup otype *local_data [[threadgroup(0)]], \
uint2 tid [[threadgroup_position_in_grid]], \
uint2 lid [[thread_position_in_threadgroup]], \
uint2 lsize [[threads_per_threadgroup]]);

#define instantiate_contiguous_strided(name, itype, otype, op) \
instantiate_contiguous_strided_helper(name, itype, otype, op) \
instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 1) \
instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 2) \
instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 3) \
instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 4)
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]]);

///////////////////////////////////////////////////////////////////////////////
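For reference, the access pattern that col_reduce_general (via _contiguous_strided_reduce) parallelizes across a threadgroup: every output element accumulates reduction_size inputs spaced reduction_stride apart, starting at the offset elem_to_loc resolves for that output. Below is a serial host-side sketch of that loop; the name strided_reduce_reference and its parameters are assumptions for illustration, and the real kernel additionally tiles this loop over threads and threadgroup memory.

// Serial reference for one output element of a strided reduction (sketch).
#include <cstddef>

template <typename T, typename U, typename Op>
U strided_reduce_reference(
    const T* in,                   // base pointer of the input
    std::size_t start,             // offset of the first element for this output
    std::size_t reduction_size,    // number of elements folded into the output
    std::size_t reduction_stride,  // distance between consecutive elements
    Op op,
    U init) {
  U acc = init;
  for (std::size_t r = 0; r < reduction_size; ++r) {
    acc = op(static_cast<U>(in[start + r * reduction_stride]), acc);
  }
  return acc;
}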
@@ -461,10 +319,8 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>

#define instantiate_reduce(name, itype, otype, op) \
instantiate_all_reduce(name, itype, otype, op) \
instantiate_row_reduce(name, itype, otype, op) \
instantiate_col_reduce(name, itype, otype, op) \
instantiate_contiguous_strided(name, itype, otype, op) \
instantiate_general_reduce(name, itype, otype, op)
instantiate_row_reduce_general(name, itype, otype, op) \
instantiate_col_reduce_general(name, itype, otype, op)

#define instantiate_same_reduce(name, tname, type, op) \
instantiate_init_reduce(name ##tname, type, op<type>) \
@@ -535,4 +391,4 @@ instantiate_same_reduce(max_, float16, half, Max)
instantiate_same_reduce(max_, float32, float, Max)

instantiate_same_reduce(min_, bfloat16, bfloat16_t, Min)
instantiate_same_reduce(max_, bfloat16, bfloat16_t, Max)
instantiate_same_reduce(max_, bfloat16, bfloat16_t, Max)
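Since both row_reduce_general and col_reduce_general now fold partial results into the output with op.atomic_update, and the extra grid dimension means several threadgroups can target the same output element, the output buffer needs to start out holding the op's identity before these kernels run; that is presumably what the instantiate_init_reduce kernels referenced just above provide. A host-side sketch of that precondition (names here are illustrative, with Op::init standing in for the identity value as in the kernels):

// Sketch: what the init step amounts to before the atomic reduce kernels run.
#include <algorithm>
#include <cstddef>

template <typename U, typename Op>
void init_reduce_output_sketch(U* out, std::size_t n) {
  // Fill the output with the reduction identity (e.g. 0 for Sum) so that
  // atomic updates can be applied in any order.
  std::fill(out, out + n, Op::init);
}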
@@ -2,9 +2,11 @@

#include <algorithm>
#include <cassert>
#include <iostream>
#include <sstream>

#include "mlx/backend/common/reduce.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/utils.h"
@@ -61,22 +63,47 @@ void all_reduce_dispatch(
compute_encoder->dispatchThreads(grid_dims, group_dims);
}

void row_reduce_dispatch(
void row_reduce_general_dispatch(
const array& in,
array& out,
const std::string& op_name,
const std::vector<int>& axes_,
const ReductionPlan& plan,
const std::vector<int>& axes,
MTL::ComputeCommandEncoder* compute_encoder,
metal::Device& d) {
auto kernel = d.get_kernel("row_reduce_" + op_name + type_to_name(in));
auto kernel =
d.get_kernel("row_reduce_general_" + op_name + type_to_name(in));

// Prepare the arguments for the kernel
int n_reads = REDUCE_N_READS;
size_t reduction_size = in.size() / out.size();
size_t reduction_size = plan.shape.back();
size_t out_size = out.size();
auto shape = plan.shape;
auto strides = plan.strides;
shape.pop_back();
strides.pop_back();
size_t non_row_reductions = 1;
for (auto s : shape) {
non_row_reductions *= static_cast<size_t>(s);
}
auto [rem_shape, rem_strides] = shapes_without_reduction_axes(in, axes);
for (auto s : rem_shape) {
shape.push_back(s);
}
for (auto s : rem_strides) {
strides.push_back(s);
}
int ndim = shape.size();

// Set the arguments for the kernel
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4);
compute_encoder->setBytes(strides.data(), strides.size() * sizeof(size_t), 5);
compute_encoder->setBytes(&ndim, sizeof(int), 6);

// Each thread group is responsible for 1 output
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
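A worked example of the argument packing above, using the case sum_axis --size 16x128x1024 --axis 0,2 from the benchmark list at the top of this diff. The concrete numbers below are derived by hand from the dispatch code, not produced by the real dispatcher: the innermost merged reduction is the contiguous 1024-long row, the leftover axis-0 reduction of extent 16 becomes the grid's second dimension, and the kernel recovers each row's start offset with the same shape/strides decomposition as elem_to_loc. The helper name row_start is invented for illustration.

#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  // Plan for reducing axes {0, 2} of a row-contiguous 16x128x1024 array.
  std::size_t reduction_size = 1024;    // plan.shape.back(): the contiguous rows
  std::size_t non_row_reductions = 16;  // leftover reduction extent -> grid y
  std::size_t out_size = 128;           // number of output elements

  // shape/strides handed to the kernel: leftover reduction dims first,
  // then the non-reduced dims of the input.
  std::vector<int> shape = {16, 128};
  std::vector<std::size_t> strides = {128ul * 1024ul, 1024ul};

  // Offset at which the threadgroup at (tid.x, tid.y) starts its row,
  // mirroring elem_to_loc(tid.y * out_size + tid.x, shape, strides, ndim).
  auto row_start = [&](std::size_t tidx, std::size_t tidy) {
    std::size_t elem = tidy * out_size + tidx;
    std::size_t loc = 0;
    for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
      loc += (elem % shape[i]) * strides[i];
      elem /= shape[i];
    }
    return loc;
  };

  // Expected: 3 * 131072 + 5 * 1024 = 398336. The 1024 contiguous elements
  // from that offset form one slice of the reduction; the 16 slices along
  // grid y all atomically fold into out[5].
  std::printf("row (x=5, y=3) starts at offset %zu\n", row_start(5, 3));
  (void)reduction_size;
  (void)non_row_reductions;
  return 0;
}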
@@ -91,92 +118,54 @@ void row_reduce_dispatch(

// Launch enough thread groups for each output
size_t n_threads = out.size() * thread_group_size;
MTL::Size grid_dims = MTL::Size(n_threads, 1, 1);
MTL::Size grid_dims = MTL::Size(n_threads, non_row_reductions, 1);
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);

compute_encoder->dispatchThreads(grid_dims, group_dims);
}

void col_reduce_dispatch(
void strided_reduce_general_dispatch(
const array& in,
array& out,
const std::string& op_name,
const std::vector<int>& axes_,
const ReductionPlan& plan,
const std::vector<int>& axes,
MTL::ComputeCommandEncoder* compute_encoder,
metal::Device& d) {
std::ostringstream kernel_name;
auto kernel =
d.get_kernel("col_reduce_general_" + op_name + type_to_name(in));

bool encode_in_shape = false;
bool encode_ndim = false;

// If the slowest moving axis can be merged into the reductions,
// we call the column reduce kernel
// In this case, a linear index in the output corresponds to the
// linear index in the input where the reduction starts
if (axes_[axes_.size() - 1] == (axes_.size() - 1)) {
kernel_name << "col_reduce_" << op_name << type_to_name(in);
}
// Otherwise, while all the reduction axes can be merged, the mapping between
// indices in the output and input require resolving using shapes and strides
else {
kernel_name << "contiguous_strided_reduce_" << op_name << type_to_name(in);
encode_in_shape = true;

// We check for a viable template with the required number of dimensions
// we only care about encoding non-reduced shapes and strides in the input
size_t non_reducing_dims = in.ndim() - axes_.size();
if (non_reducing_dims >= 1 &&
non_reducing_dims <= MAX_REDUCE_SPECIALIZED_DIMS) {
kernel_name << "_dim_" << non_reducing_dims;
} else {
encode_ndim = true;
}
}

auto kernel = d.get_kernel(kernel_name.str());
size_t in_size = in.size();
// Prepare the arguments for the kernel
size_t reduction_size = plan.shape.back();
size_t reduction_stride = plan.strides.back();
size_t out_size = out.size();
auto shape = plan.shape;
auto strides = plan.strides;
shape.pop_back();
strides.pop_back();
size_t non_col_reductions = 1;
for (auto s : shape) {
non_col_reductions *= static_cast<size_t>(s);
}
auto [rem_shape, rem_strides] = shapes_without_reduction_axes(in, axes);
for (auto s : rem_shape) {
shape.push_back(s);
}
for (auto s : rem_strides) {
strides.push_back(s);
}
int ndim = shape.size();

// Set the arguments for the kernel
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);

// Calculate the number of inputs to reduce and the stride b/w them
size_t reduction_size = 1;
size_t in_ndim = in.ndim();
size_t reduction_stride = in_size;

for (int i : axes_) {
reduction_size *= in.shape(i);
reduction_stride = std::min(reduction_stride, in.strides()[i]);
}

compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
compute_encoder->setBytes(&out_size, sizeof(size_t), 4);
if (encode_in_shape) {
// Obtain the non-reducing shape and strides of the input to encode
std::vector<int> inp_shape_mod;
std::vector<size_t> inp_strides_mod;

for (size_t i = 0, j = 0; i < in.ndim(); i++) {
if (j < axes_.size() && axes_[j] == i) {
j++;
} else {
inp_shape_mod.push_back(in.shape(i));
inp_strides_mod.push_back(in.strides()[i]);
}
}

size_t ndim = inp_shape_mod.size();

compute_encoder->setBytes(inp_shape_mod.data(), ndim * sizeof(int), 5);
compute_encoder->setBytes(inp_strides_mod.data(), ndim * sizeof(size_t), 6);

if (encode_ndim) {
compute_encoder->setBytes(&ndim, sizeof(size_t), 7);
}
}
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5);
compute_encoder->setBytes(strides.data(), strides.size() * sizeof(size_t), 6);
compute_encoder->setBytes(&ndim, sizeof(int), 7);

// Select block dimensions
@@ -200,7 +189,8 @@ void col_reduce_dispatch(
(n_threads_per_output + threadgroup_dim_y - 1) / threadgroup_dim_y;

// Launch enough thread groups for each output
MTL::Size grid_dims = MTL::Size(n_threadgroups_x, n_threadgroups_y, 1);
MTL::Size grid_dims =
MTL::Size(n_threadgroups_x, n_threadgroups_y, non_col_reductions);
MTL::Size group_dims = MTL::Size(threadgroup_dim_x, threadgroup_dim_y, 1);

// We set shared memory to be exploited here for reductions within a
@@ -216,60 +206,6 @@ void col_reduce_dispatch(
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}

void general_reduce_dispatch(
const array& in,
array& out,
const std::string& op_name,
const std::vector<int>& axes_,
MTL::ComputeCommandEncoder* compute_encoder,
metal::Device& d) {
bool encode_ndim = true;
std::ostringstream kernel_name;
kernel_name << "general_reduce_" << op_name << type_to_name(in);

// Check for specialzed kernels for input ndim
if (in.ndim() >= 1 && in.ndim() <= MAX_REDUCE_SPECIALIZED_DIMS) {
kernel_name << "_dim_" << in.ndim();
encode_ndim = false;
}
auto kernel = d.get_kernel(kernel_name.str());
size_t in_size = in.size();
size_t ndim = in.ndim();

// We set the reducing strides to 0 to induce collisions for the reduction
std::vector<size_t> out_strides(ndim);
size_t stride = 1;
for (int i = ndim - 1, j = axes_.size() - 1; i >= 0; --i) {
if (j >= 0 && axes_[j] == i) {
out_strides[i] = 0;
--j;
} else {
out_strides[i] = stride;
stride *= in.shape(i);
}
}

compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(in.shape().data(), ndim * sizeof(int), 2);
compute_encoder->setBytes(in.strides().data(), ndim * sizeof(size_t), 3);
compute_encoder->setBytes(out_strides.data(), ndim * sizeof(size_t), 4);
if (encode_ndim) {
compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
}

NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > in_size) {
thread_group_size = in_size;
}
size_t nthreads = in_size;

MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}

} // namespace

//////////////////////////////////////////////////////////////////////
@@ -278,7 +214,7 @@ void general_reduce_dispatch(

void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
array in = inputs[0];

// TODO: Allow specific row and column reductions with types disabled
// due to atomics ?
@@ -335,36 +271,46 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {

// Reduce
{
// Check for contiguous data
if (in.size() == in.data_size() &&
(in.flags().row_contiguous || in.flags().col_contiguous)) {
// Go to all reduce if reducing over all axes
if (axes_.size() == in.ndim()) {
all_reduce_dispatch(in, out, op_name, compute_encoder, d);
return;
}
// Use specialized kernels if the input is row contiguous and
// the reducing axes can be merged into one
else if (
in.flags().row_contiguous && in.strides().back() == 1 &&
(axes_.back() - axes_.front()) == axes_.size() - 1) {
// If the fastest moving axis is being reduced, go to row reduce
if (axes_[0] == (in.ndim() - axes_.size())) {
row_reduce_dispatch(in, out, op_name, axes_, compute_encoder, d);
return;
}
// Otherwise go to to generalized strided reduce
// Note: bool isn't support here yet due to the use of atomics
// once that is updated, this should be the else condition of this
// branch
else if (in.dtype() != bool_) {
col_reduce_dispatch(in, out, op_name, axes_, compute_encoder, d);
return;
}
}
std::vector<array> copies;
ReductionPlan plan = get_reduction_plan(in, axes_);

// If it is a general reduce then copy the input to a contiguous array and
// recompute the plan.
if (plan.type == GeneralReduce) {
array in_copy(in.shape(), in.dtype(), nullptr, {});
copy_gpu(in, in_copy, CopyType::General, s);
copies.push_back(in_copy);
in = in_copy;
plan = get_reduction_plan(in, axes_);
}

// Reducing over everything and the data is all there no broadcasting or
// slicing etc.
if (plan.type == ContiguousAllReduce) {
all_reduce_dispatch(in, out, op_name, compute_encoder, d);
}

// At least the last dimension is row contiguous and we are reducing over
// the last dim.
else if (
plan.type == ContiguousReduce || plan.type == GeneralContiguousReduce) {
row_reduce_general_dispatch(
in, out, op_name, plan, axes_, compute_encoder, d);
}

// At least the last two dimensions are contiguous and we are doing a
// strided reduce over these.
else if (
plan.type == ContiguousStridedReduce ||
plan.type == GeneralStridedReduce) {
strided_reduce_general_dispatch(
in, out, op_name, plan, axes_, compute_encoder, d);
}

if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
// Fall back to the general case
general_reduce_dispatch(in, out, op_name, axes_, compute_encoder, d);
}
}