Mirror of https://github.com/ml-explore/mlx.git (synced 2025-08-21 11:38:06 +08:00)
Reduce update (#783)

* Split reduction files to reduce compile times
* Add small and medium axis size specializations for row reductions
* Add non-row-reduction options for small and med kernels

parent c096a77b9b
commit 6686e61ca4
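The commit splits the monolithic reduce kernel source into per-kernel files and adds size-specialized row-reduction kernels. As a rough illustration of what such specializations enable on the host side (the thresholds here are hypothetical, not MLX's actual dispatch logic; only the kernel names come from this diff):

```python
# Hypothetical host-side kernel chooser; thresholds are illustrative,
# not MLX's actual dispatch logic. Kernel names appear in this commit.
def pick_row_reduce_kernel(row_size: int, simd_size: int = 32) -> str:
    if row_size <= simd_size:  # one thread can walk the whole row
        return "row_reduce_general_small"
    if row_size <= simd_size * simd_size:  # one simdgroup per output
        return "row_reduce_general_med"
    return "row_reduce_general"  # full threadgroup plus simd tree

for n in (8, 512, 1 << 20):
    print(n, "->", pick_row_reduce_kernel(n))
```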
@@ -380,10 +380,6 @@ if __name__ == "__main__":
     if len(args.axis) > 1:
         args.axis.pop(0)
 
-    if args.print_pid:
-        print(os.getpid())
-        input("Press enter to run")
-
     if args.cpu:
         mx.set_default_device(mx.cpu)
     else:
@@ -406,6 +402,10 @@ if __name__ == "__main__":
     x = xs[0]
     axis = args.axis[0]
 
+    if args.print_pid:
+        print(os.getpid())
+        input("Press enter to run")
+
     if args.benchmark == "matmul_square":
         print(bench(matmul_square, x))
@@ -331,10 +331,6 @@ if __name__ == "__main__":
     if len(args.axis) > 1:
         args.axis.pop(0)
 
-    if args.print_pid:
-        print(os.getpid())
-        input("Press enter to run")
-
     torch.set_num_threads(1)
     device = "cpu" if args.cpu else "mps"
 
@@ -354,6 +350,10 @@ if __name__ == "__main__":
     x = xs[0]
     axis = args.axis[0]
 
+    if args.print_pid:
+        print(os.getpid())
+        input("Press enter to run")
+
     if args.benchmark == "matmul_square":
         print(bench(matmul_square, x))
@@ -8,7 +8,6 @@ set(
   ${CMAKE_CURRENT_SOURCE_DIR}/defines.h
   ${CMAKE_CURRENT_SOURCE_DIR}/erf.h
   ${CMAKE_CURRENT_SOURCE_DIR}/indexing.h
-  ${CMAKE_CURRENT_SOURCE_DIR}/reduce.h
   ${CMAKE_CURRENT_SOURCE_DIR}/unary.h
   ${CMAKE_CURRENT_SOURCE_DIR}/utils.h
 )
@@ -24,7 +23,6 @@ set(
   "gemv"
   "quantized"
   "random"
-  "reduce"
   "rope"
   "scan"
   "softmax"
@@ -68,6 +66,15 @@ foreach(KERNEL ${STEEL_KERNELS})
   set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR})
 endforeach()
 
+file(GLOB_RECURSE REDUCE_KERNELS ${CMAKE_CURRENT_SOURCE_DIR}/reduction/*.metal)
+file(GLOB_RECURSE REDUCE_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/reduction/*.h)
+
+foreach(KERNEL ${REDUCE_KERNELS})
+  cmake_path(GET KERNEL STEM TARGET)
+  build_kernel_base(${TARGET} ${KERNEL} "${REDUCE_HEADERS}")
+  set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR})
+endforeach()
+
 add_custom_command(
   OUTPUT ${MLX_METAL_PATH}/mlx.metallib
   COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o ${MLX_METAL_PATH}/mlx.metallib
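The new loop compiles each reduction/*.metal source to a .air object and links everything into mlx.metallib, so touching one reduction kernel no longer recompiles them all. A rough Python equivalent of that pipeline, assuming the Xcode command-line tools are installed and using illustrative paths:

```python
# Rough Python sketch of the CMake pipeline above; paths are illustrative.
import glob
import pathlib
import subprocess

air_files = []
for src in glob.glob("reduction/**/*.metal", recursive=True):
    air = str(pathlib.Path(src).with_suffix(".air"))
    # compile each kernel source to an AIR object
    subprocess.run(["xcrun", "-sdk", "macosx", "metal", "-c", src, "-o", air], check=True)
    air_files.append(air)

# link all AIR objects into a single metallib, as the add_custom_command does
subprocess.run(["xcrun", "-sdk", "macosx", "metallib", *air_files, "-o", "mlx.metallib"], check=True)
```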
@@ -1,619 +0,0 @@
// Copyright © 2023 Apple Inc.

#include <metal_atomic>
#include <metal_simdgroup>

#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/reduce.h"
#include "mlx/backend/metal/kernels/utils.h"

using namespace metal;

static constant uint8_t simd_size = 32;

template <typename T, typename Op>
[[kernel]] void init_reduce(
    device T *out [[buffer(0)]],
    uint tid [[thread_position_in_grid]]) {
  out[tid] = Op::init;
}

#define instantiate_init_reduce(name, otype, op) \
  template [[host_name("i" #name)]] \
  [[kernel]] void init_reduce<otype, op>( \
      device otype *out [[buffer(1)]], \
      uint tid [[thread_position_in_grid]]);

///////////////////////////////////////////////////////////////////////////////
// All reduce
///////////////////////////////////////////////////////////////////////////////

template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
inline U per_thread_all_reduce(
    const device T *in,
    const device size_t& in_size,
    uint gid,
    uint grid_size) {
  Op op;
  U total_val = Op::init;

  if (gid * N_READS < in_size) {
    in += gid * N_READS;

    int r = 0;
    for(; r < (int)ceildiv(in_size, grid_size * N_READS) - 1; r++) {
      U vals[N_READS] = {op.init};

      for(int i = 0; i < N_READS; i++) {
        vals[i] = static_cast<U>(in[i]);
      }
      for(int i = 0; i < N_READS; i++) {
        total_val = op(vals[i], total_val);
      }

      in += grid_size * N_READS;
    }

    // Separate case for the last set as we close the reduction size
    size_t curr_idx = (gid + r * (size_t)grid_size) * N_READS;
    if (curr_idx < in_size) {
      int max_reads = in_size - curr_idx;
      T vals[N_READS];

      for(int i = 0, idx = 0; i < N_READS; i++, idx++) {
        idx = idx < max_reads ? idx : max_reads - 1;
        vals[i] = in[idx];
      }
      for(int i = 0; i < N_READS; i++) {
        U val = i < max_reads ? vals[i] : Op::init;
        total_val = op(static_cast<U>(val), total_val);
      }
    }
  }

  return total_val;
}

// NB: This kernel assumes threads_per_threadgroup is at most
// 1024. This way with a simd_size of 32, we are guaranteed to
// complete the reduction in two steps of simd-level reductions.
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void all_reduce(
    const device T *in [[buffer(0)]],
    device mlx_atomic<U> *out [[buffer(1)]],
    const device size_t& in_size [[buffer(2)]],
    uint gid [[thread_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint grid_size [[threads_per_grid]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  Op op;
  threadgroup U local_vals[simd_size];

  U total_val = per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);

  // Reduction within simd group
  total_val = op.simd_reduce(total_val);
  if (simd_lane_id == 0) {
    local_vals[simd_group_id] = total_val;
  }

  // Reduction within thread group
  threadgroup_barrier(mem_flags::mem_threadgroup);
  total_val = lid < simd_per_group ? local_vals[lid] : op.init;
  total_val = op.simd_reduce(total_val);

  // Reduction across threadgroups
  if (lid == 0) {
    op.atomic_update(out, total_val);
  }
}

template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void all_reduce_no_atomics(
    const device T *in [[buffer(0)]],
    device U *out [[buffer(1)]],
    const device size_t& in_size [[buffer(2)]],
    uint gid [[thread_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint grid_size [[threads_per_grid]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint thread_group_id [[threadgroup_position_in_grid]]) {
  Op op;
  threadgroup U local_vals[simd_size];

  U total_val = per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);

  // Reduction within simd group (simd_add isn't supported for uint64/int64 types)
  for (uint16_t lane_offset = simd_size/2; lane_offset > 0; lane_offset /= 2) {
    total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
  }
  // Write simd group reduction results to local memory
  if (simd_lane_id == 0) {
    local_vals[simd_group_id] = total_val;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Reduction of simdgroup reduction results within threadgroup.
  total_val = lid < simd_per_group ? local_vals[lid] : op.init;
  for (uint16_t lane_offset = simd_size/2; lane_offset > 0; lane_offset /= 2) {
    total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
  }

  // Reduction across threadgroups
  if (lid == 0) {
    out[thread_group_id] = total_val;
  }
}

#define instantiate_all_reduce(name, itype, otype, op) \
  template [[host_name("all_reduce_" #name)]] \
  [[kernel]] void all_reduce<itype, otype, op>( \
      const device itype *in [[buffer(0)]], \
      device mlx_atomic<otype> *out [[buffer(1)]], \
      const device size_t& in_size [[buffer(2)]], \
      uint gid [[thread_position_in_grid]], \
      uint lid [[thread_position_in_threadgroup]], \
      uint grid_size [[threads_per_grid]], \
      uint simd_per_group [[simdgroups_per_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \
  template [[host_name("all_reduce_no_atomics_" #name)]] \
  [[kernel]] void all_reduce_no_atomics<itype, otype, op>( \
      const device itype *in [[buffer(0)]], \
      device otype *out [[buffer(1)]], \
      const device size_t& in_size [[buffer(2)]], \
      uint gid [[thread_position_in_grid]], \
      uint lid [[thread_position_in_threadgroup]], \
      uint grid_size [[threads_per_grid]], \
      uint simd_per_group [[simdgroups_per_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
      uint thread_group_id [[threadgroup_position_in_grid]]);

///////////////////////////////////////////////////////////////////////////////
// Row atomics
///////////////////////////////////////////////////////////////////////////////

template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
inline U per_thread_row_reduce(
    const device T *in,
    const constant size_t& reduction_size,
    const constant size_t& out_size,
    const constant int* shape,
    const constant size_t* strides,
    const constant int& ndim,
    uint lsize_x,
    uint lid_x,
    uint2 tid) {
  Op op;

  // Each threadgroup handles 1 reduction
  // TODO: Specializing elem_to_loc would be slightly faster
  int idx = tid.y * out_size + tid.x;
  int extra_offset = elem_to_loc(idx, shape, strides, ndim);
  in += extra_offset + lid_x * N_READS;

  // The reduction is accumulated here
  U total_val = Op::init;

  // Loop over the reduction size within thread group
  int r = 0;
  for (; r < (int)ceildiv(reduction_size, N_READS*lsize_x) - 1; r++) {
    T vals[N_READS];
    for(int i = 0; i < N_READS; i++) {
      vals[i] = in[i];
    }
    for(int i = 0; i < N_READS; i++) {
      total_val = op(static_cast<U>(vals[i]), total_val);
    }

    in += lsize_x * N_READS;
  }

  // Separate case for the last set as we close the reduction size
  size_t reduction_index = (lid_x + (size_t)lsize_x * r) * N_READS;
  if(reduction_index < reduction_size) {
    int max_reads = reduction_size - reduction_index;

    T vals[N_READS];
    for(int i = 0; i < N_READS; i++) {
      int idx = min(i, max_reads - 1);
      vals[i] = static_cast<U>(in[idx]);
    }
    for(int i = 0; i < N_READS; i++) {
      T val = i < max_reads ? vals[i] : Op::init;
      total_val = op(static_cast<U>(val), total_val);
    }
  }

  return total_val;
}

template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void row_reduce_general(
    const device T *in [[buffer(0)]],
    device mlx_atomic<U> *out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& out_size [[buffer(3)]],
    const constant int* shape [[buffer(4)]],
    const constant size_t* strides [[buffer(5)]],
    const constant int& ndim [[buffer(6)]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  Op op;
  threadgroup U local_vals[simd_size];

  U total_val = per_thread_row_reduce<T, U, Op, N_READS>(in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy);

  total_val = op.simd_reduce(total_val);

  // Prepare next level
  if (simd_lane_id == 0) {
    local_vals[simd_group_id] = total_val;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Reduction within thread group
  // Only needed if multiple simd groups
  if(reduction_size > simd_size) {
    total_val = lid.x < simd_per_group ? local_vals[lid.x] : op.init;
    total_val = op.simd_reduce(total_val);
  }
  // Update output
  if (lid.x == 0) {
    op.atomic_update(out, total_val, tid.x);
  }
}

template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void row_reduce_general_no_atomics(
    const device T *in [[buffer(0)]],
    device U *out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& out_size [[buffer(3)]],
    const constant int* shape [[buffer(4)]],
    const constant size_t* strides [[buffer(5)]],
    const constant int& ndim [[buffer(6)]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]],
    uint3 gsize [[threads_per_grid]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  Op op;

  threadgroup U local_vals[simd_size];
  U total_val = per_thread_row_reduce<T, U, Op, N_READS>(in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy);

  // Reduction within simd group - simd_add isn't supported for int64 types
  for (uint16_t i = simd_size/2; i > 0; i /= 2) {
    total_val = op(total_val, simd_shuffle_down(total_val, i));
  }

  // Prepare next level
  if (simd_lane_id == 0) {
    local_vals[simd_group_id] = total_val;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Reduction within thread group
  // Only needed if thread group has multiple simd groups
  if(ceildiv(reduction_size, N_READS) > simd_size) {
    total_val = lid.x < simd_per_group ? local_vals[lid.x] : op.init;
    for (uint16_t i = simd_size/2; i > 0; i /= 2) {
      total_val = op(total_val, simd_shuffle_down(total_val, i));
    }
  }
  // Write row reduce output for threadgroup with 1st thread in thread group
  if (lid.x == 0) {
    out[(ceildiv(gsize.y, lsize.y) * tid.x) + tid.y] = total_val;
  }
}

#define instantiate_row_reduce_general(name, itype, otype, op) \
  template [[host_name("row_reduce_general_" #name)]] \
  [[kernel]] void row_reduce_general<itype, otype, op>( \
      const device itype *in [[buffer(0)]], \
      device mlx_atomic<otype> *out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
      const constant size_t& out_size [[buffer(3)]], \
      const constant int* shape [[buffer(4)]], \
      const constant size_t* strides [[buffer(5)]], \
      const constant int& ndim [[buffer(6)]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint3 lsize [[threads_per_threadgroup]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_per_group [[simdgroups_per_threadgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \
  template [[host_name("row_reduce_general_no_atomics_" #name)]] \
  [[kernel]] void row_reduce_general_no_atomics<itype, otype, op>( \
      const device itype *in [[buffer(0)]], \
      device otype *out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
      const constant size_t& out_size [[buffer(3)]], \
      const constant int* shape [[buffer(4)]], \
      const constant size_t* strides [[buffer(5)]], \
      const constant int& ndim [[buffer(6)]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint3 lsize [[threads_per_threadgroup]], \
      uint3 gsize [[threads_per_grid]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_per_group [[simdgroups_per_threadgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

///////////////////////////////////////////////////////////////////////////////
// Column reduce
///////////////////////////////////////////////////////////////////////////////

template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
inline U _contiguous_strided_reduce(
    const device T *in,
    threadgroup U *local_data,
    uint in_idx,
    uint reduction_size,
    uint reduction_stride,
    uint2 tid,
    uint2 lid,
    uint2 lsize) {
  Op op;
  U total_val = Op::init;

  uint base_offset = (tid.y * lsize.y + lid.y) * N_READS;
  for(uint r = 0; r < N_READS && (base_offset + r) < reduction_size; r++) {
    uint offset = base_offset + r;
    total_val = op(static_cast<U>(total_val), in[in_idx + offset * reduction_stride]);
  }
  local_data[lsize.y * lid.x + lid.y] = total_val;
  threadgroup_barrier(mem_flags::mem_threadgroup);

  U val = Op::init;
  if(lid.y == 0) {
    // Perform reduction across columns in thread group
    for(uint i = 0; i < lsize.y; i++) {
      val = op(val, local_data[lsize.y * lid.x + i]);
    }
  }

  return val;
}

template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void col_reduce_general(
    const device T *in [[buffer(0)]],
    device mlx_atomic<U> *out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant size_t& out_size [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    threadgroup U *local_data [[threadgroup(0)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]]) {
  auto out_idx = tid.x * lsize.x + lid.x;
  auto in_idx = elem_to_loc(
      out_idx + tid.z * out_size,
      shape,
      strides,
      ndim
  );

  Op op;
  if(out_idx < out_size) {
    U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
        in,
        local_data,
        in_idx,
        reduction_size,
        reduction_stride,
        tid.xy,
        lid.xy,
        lsize.xy);

    // Write out reduction results generated by threadgroups working on specific output element, contiguously.
    if (lid.y == 0) {
      op.atomic_update(out, val, out_idx);
    }
  }
}

template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void col_reduce_general_no_atomics(
    const device T *in [[buffer(0)]],
    device U *out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant size_t& out_size [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    threadgroup U *local_data [[threadgroup(0)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 gid [[thread_position_in_grid]],
    uint3 lsize [[threads_per_threadgroup]],
    uint3 gsize [[threads_per_grid]]) {
  auto out_idx = tid.x * lsize.x + lid.x;
  auto in_idx = elem_to_loc(
      out_idx + tid.z * out_size,
      shape,
      strides,
      ndim
  );

  if(out_idx < out_size) {
    U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
        in,
        local_data,
        in_idx,
        reduction_size,
        reduction_stride,
        tid.xy,
        lid.xy,
        lsize.xy);

    // Write out reduction results generated by threadgroups working on specific output element, contiguously.
    if (lid.y == 0) {
      uint tgsize_y = ceildiv(gsize.y, lsize.y);
      uint tgsize_z = ceildiv(gsize.z, lsize.z);
      out[tgsize_y * tgsize_z * gid.x + tgsize_y * tid.z + tid.y] = val;
    }
  }
}

#define instantiate_col_reduce_general(name, itype, otype, op) \
  template [[host_name("col_reduce_general_" #name)]] \
  [[kernel]] void col_reduce_general<itype, otype, op>( \
      const device itype *in [[buffer(0)]], \
      device mlx_atomic<otype> *out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
      const constant size_t& reduction_stride [[buffer(3)]], \
      const constant size_t& out_size [[buffer(4)]], \
      const constant int* shape [[buffer(5)]], \
      const constant size_t* strides [[buffer(6)]], \
      const constant int& ndim [[buffer(7)]], \
      threadgroup otype *local_data [[threadgroup(0)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint3 lsize [[threads_per_threadgroup]]);

#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \
  template [[host_name("col_reduce_general_no_atomics_" #name)]] \
  [[kernel]] void col_reduce_general_no_atomics<itype, otype, op>( \
      const device itype *in [[buffer(0)]], \
      device otype *out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
      const constant size_t& reduction_stride [[buffer(3)]], \
      const constant size_t& out_size [[buffer(4)]], \
      const constant int* shape [[buffer(5)]], \
      const constant size_t* strides [[buffer(6)]], \
      const constant int& ndim [[buffer(7)]], \
      threadgroup otype *local_data [[threadgroup(0)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint3 gid [[thread_position_in_grid]], \
      uint3 lsize [[threads_per_threadgroup]], \
      uint3 gsize [[threads_per_grid]]);

///////////////////////////////////////////////////////////////////////////////
// Instantiations
///////////////////////////////////////////////////////////////////////////////

#define instantiate_reduce(name, itype, otype, op) \
  instantiate_all_reduce(name, itype, otype, op) \
  instantiate_row_reduce_general(name, itype, otype, op) \
  instantiate_col_reduce_general(name, itype, otype, op)

#define instantiate_reduce_no_atomics(name, itype, otype, op) \
  instantiate_all_reduce_no_atomics(name, itype, otype, op) \
  instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \
  instantiate_col_reduce_general_no_atomics(name, itype, otype, op)

#define instantiate_same_reduce_no_atomics(name, tname, type, op) \
  instantiate_init_reduce(name ##tname, type, op<type>) \
  instantiate_reduce_no_atomics(name ##tname, type, type, op<type>)

#define instantiate_same_reduce(name, tname, type, op) \
  instantiate_init_reduce(name ##tname, type, op<type>) \
  instantiate_reduce(name ##tname, type, type, op<type>)

#define instantiate_reduce_from_types_helper(name, tname, itype, otype, op) \
  instantiate_reduce(name ##tname, itype, otype, op)

#define instantiate_reduce_from_types(name, otype, op) \
  instantiate_reduce_from_types_helper(name, bool_, bool, otype, op) \
  instantiate_reduce_from_types_helper(name, uint8, uint8_t, otype, op) \
  instantiate_reduce_from_types_helper(name, uint16, uint16_t, otype, op) \
  instantiate_reduce_from_types_helper(name, uint32, uint32_t, otype, op) \
  instantiate_reduce_from_types_helper(name, int8, int8_t, otype, op) \
  instantiate_reduce_from_types_helper(name, int16, int16_t, otype, op) \
  instantiate_reduce_from_types_helper(name, int32, int32_t, otype, op) \
  instantiate_reduce_from_types_helper(name, int64, int64_t, otype, op) \
  instantiate_reduce_from_types_helper(name, float16, half, otype, op) \
  instantiate_reduce_from_types_helper(name, float32, float, otype, op) \
  instantiate_reduce_from_types_helper(name, bfloat16, bfloat16_t, otype, op)

// special case bool with larger output type
instantiate_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_same_reduce(sum, uint8, uint8_t, Sum)
instantiate_same_reduce(sum, uint16, uint16_t, Sum)
instantiate_same_reduce(sum, uint32, uint32_t, Sum)
instantiate_same_reduce(sum, int8, int8_t, Sum)
instantiate_same_reduce(sum, int16, int16_t, Sum)
instantiate_same_reduce(sum, int32, int32_t, Sum)
instantiate_same_reduce(sum, float16, half, Sum)
instantiate_same_reduce(sum, float32, float, Sum)

instantiate_same_reduce_no_atomics(sum, int64, int64_t, Sum)
instantiate_same_reduce_no_atomics(sum, uint64, uint64_t, Sum)

instantiate_same_reduce(prod, uint8, uint8_t, Prod)
instantiate_same_reduce(prod, uint16, uint16_t, Prod)
instantiate_same_reduce(prod, uint32, uint32_t, Prod)
instantiate_same_reduce(prod, int8, int8_t, Prod)
instantiate_same_reduce(prod, int16, int16_t, Prod)
instantiate_same_reduce(prod, int32, int32_t, Prod)
instantiate_same_reduce(prod, float16, half, Prod)
instantiate_same_reduce(prod, float32, float, Prod)

instantiate_same_reduce_no_atomics(prod, int64, int64_t, Prod)
instantiate_same_reduce_no_atomics(prod, uint64, uint64_t, Prod)

instantiate_same_reduce(sum, bfloat16, bfloat16_t, Sum)
instantiate_same_reduce(prod, bfloat16, bfloat16_t, Prod)

instantiate_init_reduce(andbool_, bool, And)
instantiate_reduce_from_types(and, bool, And)

instantiate_init_reduce(orbool_, bool, Or)
instantiate_reduce_from_types(or, bool, Or)

// Compiler segfaulted with the names "min" or "max" ...
instantiate_same_reduce(min_, uint8, uint8_t, Min)
instantiate_same_reduce(min_, uint16, uint16_t, Min)
instantiate_same_reduce(min_, uint32, uint32_t, Min)
instantiate_same_reduce(min_, int8, int8_t, Min)
instantiate_same_reduce(min_, int16, int16_t, Min)
instantiate_same_reduce(min_, int32, int32_t, Min)
instantiate_same_reduce(min_, float16, half, Min)
instantiate_same_reduce(min_, float32, float, Min)

instantiate_same_reduce_no_atomics(min_, int64, int64_t, Min)
instantiate_same_reduce_no_atomics(min_, uint64, uint64_t, Min)

instantiate_same_reduce(max_, uint8, uint8_t, Max)
instantiate_same_reduce(max_, uint16, uint16_t, Max)
instantiate_same_reduce(max_, uint32, uint32_t, Max)
instantiate_same_reduce(max_, int8, int8_t, Max)
instantiate_same_reduce(max_, int16, int16_t, Max)
instantiate_same_reduce(max_, int32, int32_t, Max)
instantiate_same_reduce(max_, float16, half, Max)
instantiate_same_reduce(max_, float32, float, Max)

instantiate_same_reduce_no_atomics(max_, int64, int64_t, Max)
instantiate_same_reduce_no_atomics(max_, uint64, uint64_t, Max)

instantiate_same_reduce(min_, bfloat16, bfloat16_t, Min)
instantiate_same_reduce(max_, bfloat16, bfloat16_t, Max)
mlx/backend/metal/kernels/reduction/kernels/reduce_all.metal (new file, 185 lines)
@@ -0,0 +1,185 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/backend/metal/kernels/reduction/utils.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"

using namespace metal;

///////////////////////////////////////////////////////////////////////////////
// All reduce helper
///////////////////////////////////////////////////////////////////////////////

template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
METAL_FUNC U per_thread_all_reduce(
    const device T* in,
    const device size_t& in_size,
    uint gid,
    uint grid_size) {
  Op op;
  U total_val = Op::init;

  if (gid * N_READS < in_size) {
    in += gid * N_READS;

    int r = 0;
    for (; r < (int)ceildiv(in_size, grid_size * N_READS) - 1; r++) {
      U vals[N_READS] = {op.init};

      for (int i = 0; i < N_READS; i++) {
        vals[i] = static_cast<U>(in[i]);
      }
      for (int i = 0; i < N_READS; i++) {
        total_val = op(vals[i], total_val);
      }

      in += grid_size * N_READS;
    }

    // Separate case for the last set as we close the reduction size
    size_t curr_idx = (gid + r * (size_t)grid_size) * N_READS;
    if (curr_idx < in_size) {
      int max_reads = in_size - curr_idx;
      T vals[N_READS];

      for (int i = 0, idx = 0; i < N_READS; i++, idx++) {
        idx = idx < max_reads ? idx : max_reads - 1;
        vals[i] = in[idx];
      }
      for (int i = 0; i < N_READS; i++) {
        U val = i < max_reads ? vals[i] : Op::init;
        total_val = op(static_cast<U>(val), total_val);
      }
    }
  }

  return total_val;
}
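A minimal Python model of per_thread_all_reduce's indexing, assuming op = Sum: each of grid_size threads reads N_READS contiguous elements, strides by grid_size * N_READS, and clamps the final partial tile. This sketches the access pattern only, not the Metal semantics:

```python
# Python model of per_thread_all_reduce's tiling, assuming op = Sum.
def ceildiv(a, b):
    return (a + b - 1) // b

def per_thread_all_reduce(data, gid, grid_size, n_reads=4):
    total = 0
    base = gid * n_reads
    if base >= len(data):
        return total
    # full tiles: N_READS contiguous reads, then jump by grid_size * N_READS
    for _ in range(ceildiv(len(data), grid_size * n_reads) - 1):
        total += sum(data[base:base + n_reads])
        base += grid_size * n_reads
    # last tile is clamped to the end of the input (the curr_idx branch)
    total += sum(data[base:base + n_reads])
    return total

data = list(range(103))
grid_size = 8
# every element is visited exactly once across the grid
assert sum(per_thread_all_reduce(data, g, grid_size) for g in range(grid_size)) == sum(data)
```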
///////////////////////////////////////////////////////////////////////////////
// All reduce kernel
///////////////////////////////////////////////////////////////////////////////

// NB: This kernel assumes threads_per_threadgroup is at most
// 1024. This way with a simd_size of 32, we are guaranteed to
// complete the reduction in two steps of simd-level reductions.
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void all_reduce(
    const device T *in [[buffer(0)]],
    device mlx_atomic<U> *out [[buffer(1)]],
    const device size_t& in_size [[buffer(2)]],
    uint gid [[thread_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint grid_size [[threads_per_grid]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  Op op;
  threadgroup U local_vals[simd_size];

  U total_val = per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);

  // Reduction within simd group
  total_val = op.simd_reduce(total_val);
  if (simd_lane_id == 0) {
    local_vals[simd_group_id] = total_val;
  }

  // Reduction within thread group
  threadgroup_barrier(mem_flags::mem_threadgroup);
  total_val = lid < simd_per_group ? local_vals[lid] : op.init;
  total_val = op.simd_reduce(total_val);

  // Reduction across threadgroups
  if (lid == 0) {
    op.atomic_update(out, total_val);
  }
}
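The kernel reduces in two simd-level steps, which is why the comment bounds threads_per_threadgroup at 1024: 32 simdgroups of 32 lanes. A Python model of that structure, assuming op = Sum:

```python
# Python model of the two-step threadgroup reduction, assuming op = Sum:
# each 32-wide simdgroup produces one partial (step 1), lane 0 stores it
# to threadgroup memory, and one simdgroup reduces the partials (step 2).
def threadgroup_reduce(per_thread_vals, simd_size=32):
    partials = [sum(per_thread_vals[i:i + simd_size])
                for i in range(0, len(per_thread_vals), simd_size)]
    assert len(partials) <= simd_size  # fits in one more simd_reduce
    return sum(partials)

assert threadgroup_reduce(list(range(1024))) == sum(range(1024))
```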
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void all_reduce_no_atomics(
    const device T *in [[buffer(0)]],
    device U *out [[buffer(1)]],
    const device size_t& in_size [[buffer(2)]],
    uint gid [[thread_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint grid_size [[threads_per_grid]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint thread_group_id [[threadgroup_position_in_grid]]) {
  Op op;
  threadgroup U local_vals[simd_size];

  U total_val = per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);

  // Reduction within simd group (simd_add isn't supported for uint64/int64 types)
  for (uint16_t lane_offset = simd_size/2; lane_offset > 0; lane_offset /= 2) {
    total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
  }
  // Write simd group reduction results to local memory
  if (simd_lane_id == 0) {
    local_vals[simd_group_id] = total_val;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Reduction of simdgroup reduction results within threadgroup.
  total_val = lid < simd_per_group ? local_vals[lid] : op.init;
  for (uint16_t lane_offset = simd_size/2; lane_offset > 0; lane_offset /= 2) {
    total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
  }

  // Reduction across threadgroups
  if (lid == 0) {
    out[thread_group_id] = total_val;
  }
}
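The no-atomics variant replaces op.simd_reduce with an explicit simd_shuffle_down tree, since simd reductions are unavailable for 64-bit types. A Python model of one such tree, assuming op = Sum and a 32-wide simdgroup:

```python
# Python model of the simd_shuffle_down tree, assuming op = Sum. On the
# GPU, out-of-range shuffles return undefined values and only lane 0's
# result is consumed; model them as 0 here.
def simd_tree_reduce(vals):
    vals = list(vals)
    width = len(vals)  # power of two, e.g. simd_size = 32
    offset = width // 2
    while offset > 0:
        for i in range(width):  # ascending order keeps reads pre-write
            vals[i] += vals[i + offset] if i + offset < width else 0
        offset //= 2
    return vals[0]  # lane 0 holds the reduction after log2(width) steps

assert simd_tree_reduce(range(32)) == sum(range(32))
```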
#define instantiate_all_reduce(name, itype, otype, op) \
  template [[host_name("all_reduce_" #name)]] \
  [[kernel]] void all_reduce<itype, otype, op>( \
      const device itype *in [[buffer(0)]], \
      device mlx_atomic<otype> *out [[buffer(1)]], \
      const device size_t& in_size [[buffer(2)]], \
      uint gid [[thread_position_in_grid]], \
      uint lid [[thread_position_in_threadgroup]], \
      uint grid_size [[threads_per_grid]], \
      uint simd_per_group [[simdgroups_per_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \
  template [[host_name("all_reduce_no_atomics_" #name)]] \
  [[kernel]] void all_reduce_no_atomics<itype, otype, op>( \
      const device itype *in [[buffer(0)]], \
      device otype *out [[buffer(1)]], \
      const device size_t& in_size [[buffer(2)]], \
      uint gid [[thread_position_in_grid]], \
      uint lid [[thread_position_in_threadgroup]], \
      uint grid_size [[threads_per_grid]], \
      uint simd_per_group [[simdgroups_per_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
      uint thread_group_id [[threadgroup_position_in_grid]]);

///////////////////////////////////////////////////////////////////////////////
// Instantiations
///////////////////////////////////////////////////////////////////////////////

#define instantiate_same_all_reduce_helper(name, tname, type, op) \
  instantiate_all_reduce(name ##tname, type, type, op<type>)

#define instantiate_same_all_reduce_na_helper(name, tname, type, op) \
  instantiate_all_reduce_no_atomics(name ##tname, type, type, op<type>)

instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_all_reduce_na_helper, instantiate_reduce_helper_64b)

instantiate_reduce_from_types(instantiate_all_reduce, and, bool, And)
instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or)

// special case bool with larger output type
instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)
mlx/backend/metal/kernels/reduction/kernels/reduce_col.metal (new file, 184 lines)
@@ -0,0 +1,184 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/backend/metal/kernels/reduction/utils.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"

using namespace metal;

///////////////////////////////////////////////////////////////////////////////
// Column reduce helper
///////////////////////////////////////////////////////////////////////////////

template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
METAL_FUNC U _contiguous_strided_reduce(
    const device T* in,
    threadgroup U* local_data,
    uint in_idx,
    uint reduction_size,
    uint reduction_stride,
    uint2 tid,
    uint2 lid,
    uint2 lsize) {
  Op op;
  U total_val = Op::init;

  uint base_offset = (tid.y * lsize.y + lid.y) * N_READS;
  for (uint r = 0; r < N_READS && (base_offset + r) < reduction_size; r++) {
    uint offset = base_offset + r;
    total_val =
        op(static_cast<U>(total_val), in[in_idx + offset * reduction_stride]);
  }
  local_data[lsize.y * lid.x + lid.y] = total_val;
  threadgroup_barrier(mem_flags::mem_threadgroup);

  U val = Op::init;
  if (lid.y == 0) {
    // Perform reduction across columns in thread group
    for (uint i = 0; i < lsize.y; i++) {
      val = op(val, local_data[lsize.y * lid.x + i]);
    }
  }

  return val;
}
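Serially, this column reduction walks reduction_size elements spaced reduction_stride apart for each output; the threadgroup version splits that walk across lid.y and combines the partials in local_data. A Python sketch of the serial semantics, assuming op = Sum:

```python
# Python model of the serial semantics of _contiguous_strided_reduce,
# assuming op = Sum: one output reduces reduction_size inputs spaced
# reduction_stride apart.
def strided_reduce(data, in_idx, reduction_size, reduction_stride):
    total = 0
    for r in range(reduction_size):
        total += data[in_idx + r * reduction_stride]
    return total

# reduce a 4x3 row-major matrix over axis 0: stride 3 walks one column
flat = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
print([strided_reduce(flat, c, 4, 3) for c in range(3)])  # [22, 26, 30]
```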
///////////////////////////////////////////////////////////////////////////////
// Column reduce kernel
///////////////////////////////////////////////////////////////////////////////

template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void col_reduce_general(
    const device T *in [[buffer(0)]],
    device mlx_atomic<U> *out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant size_t& out_size [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    threadgroup U *local_data [[threadgroup(0)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]]) {
  auto out_idx = tid.x * lsize.x + lid.x;
  auto in_idx = elem_to_loc(
      out_idx + tid.z * out_size,
      shape,
      strides,
      ndim
  );

  Op op;
  if(out_idx < out_size) {
    U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
        in,
        local_data,
        in_idx,
        reduction_size,
        reduction_stride,
        tid.xy,
        lid.xy,
        lsize.xy);

    // Write out reduction results generated by threadgroups working on specific output element, contiguously.
    if (lid.y == 0) {
      op.atomic_update(out, val, out_idx);
    }
  }
}

template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void col_reduce_general_no_atomics(
    const device T *in [[buffer(0)]],
    device U *out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant size_t& out_size [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    threadgroup U *local_data [[threadgroup(0)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 gid [[thread_position_in_grid]],
    uint3 lsize [[threads_per_threadgroup]],
    uint3 gsize [[threads_per_grid]]) {
  auto out_idx = tid.x * lsize.x + lid.x;
  auto in_idx = elem_to_loc(
      out_idx + tid.z * out_size,
      shape,
      strides,
      ndim
  );

  if(out_idx < out_size) {
    U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
        in,
        local_data,
        in_idx,
        reduction_size,
        reduction_stride,
        tid.xy,
        lid.xy,
        lsize.xy);

    // Write out reduction results generated by threadgroups working on specific output element, contiguously.
    if (lid.y == 0) {
      uint tgsize_y = ceildiv(gsize.y, lsize.y);
      uint tgsize_z = ceildiv(gsize.z, lsize.z);
      out[tgsize_y * tgsize_z * gid.x + tgsize_y * tid.z + tid.y] = val;
    }
  }
}

#define instantiate_col_reduce_general(name, itype, otype, op) \
  template [[host_name("col_reduce_general_" #name)]] \
  [[kernel]] void col_reduce_general<itype, otype, op>( \
      const device itype *in [[buffer(0)]], \
      device mlx_atomic<otype> *out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
      const constant size_t& reduction_stride [[buffer(3)]], \
      const constant size_t& out_size [[buffer(4)]], \
      const constant int* shape [[buffer(5)]], \
      const constant size_t* strides [[buffer(6)]], \
      const constant int& ndim [[buffer(7)]], \
      threadgroup otype *local_data [[threadgroup(0)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint3 lsize [[threads_per_threadgroup]]);

#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \
  template [[host_name("col_reduce_general_no_atomics_" #name)]] \
  [[kernel]] void col_reduce_general_no_atomics<itype, otype, op>( \
      const device itype *in [[buffer(0)]], \
      device otype *out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
      const constant size_t& reduction_stride [[buffer(3)]], \
      const constant size_t& out_size [[buffer(4)]], \
      const constant int* shape [[buffer(5)]], \
      const constant size_t* strides [[buffer(6)]], \
      const constant int& ndim [[buffer(7)]], \
      threadgroup otype *local_data [[threadgroup(0)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint3 gid [[thread_position_in_grid]], \
      uint3 lsize [[threads_per_threadgroup]], \
      uint3 gsize [[threads_per_grid]]);

///////////////////////////////////////////////////////////////////////////////
// Instantiations
///////////////////////////////////////////////////////////////////////////////

#define instantiate_same_col_reduce_helper(name, tname, type, op) \
  instantiate_col_reduce_general(name ##tname, type, type, op<type>)

#define instantiate_same_col_reduce_na_helper(name, tname, type, op) \
  instantiate_col_reduce_general_no_atomics(name ##tname, type, type, op<type>)

instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_col_reduce_na_helper, instantiate_reduce_helper_64b)

instantiate_col_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_from_types(instantiate_col_reduce_general, and, bool, And)
instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or)
@@ -0,0 +1,33 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/backend/metal/kernels/reduction/utils.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"

using namespace metal;

///////////////////////////////////////////////////////////////////////////////
// Reduce init
///////////////////////////////////////////////////////////////////////////////

template <typename T, typename Op>
[[kernel]] void init_reduce(
    device T *out [[buffer(0)]],
    uint tid [[thread_position_in_grid]]) {
  out[tid] = Op::init;
}

#define instantiate_init_reduce(name, otype, op) \
  template [[host_name("i" #name)]] \
  [[kernel]] void init_reduce<otype, op>( \
      device otype *out [[buffer(1)]], \
      uint tid [[thread_position_in_grid]]);

#define instantiate_init_reduce_helper(name, tname, type, op) \
  instantiate_init_reduce(name ##tname, type, op<type>)

instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_64b)

instantiate_init_reduce(andbool_, bool, And)
instantiate_init_reduce(orbool_, bool, Or)
mlx/backend/metal/kernels/reduction/kernels/reduce_row.metal (new file, 369 lines)
@@ -0,0 +1,369 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/backend/metal/kernels/reduction/utils.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"

using namespace metal;

///////////////////////////////////////////////////////////////////////////////
// Small row reductions
///////////////////////////////////////////////////////////////////////////////

// Each thread reduces for one output
template <typename T, typename U, typename Op>
[[kernel]] void row_reduce_general_small(
    const device T *in [[buffer(0)]],
    device U *out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& out_size [[buffer(3)]],
    const constant size_t& non_row_reductions [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    uint lid [[thread_position_in_grid]]) {
  Op op;

  uint out_idx = lid;

  if(out_idx >= out_size) {
    return;
  }

  U total_val = Op::init;

  for(short r = 0; r < short(non_row_reductions); r++) {
    uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
    const device T * in_row = in + in_idx;

    for(short i = 0; i < short(reduction_size); i++) {
      total_val = op(static_cast<U>(in_row[i]), total_val);
    }
  }

  out[out_idx] = total_val;
}

// Each simdgroup reduces for one output
template <typename T, typename U, typename Op>
[[kernel]] void row_reduce_general_med(
    const device T *in [[buffer(0)]],
    device U *out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& out_size [[buffer(3)]],
    const constant size_t& non_row_reductions [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    uint tid [[threadgroup_position_in_grid]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[dispatch_simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  Op op;

  uint out_idx = simd_per_group * tid + simd_group_id;

  if(out_idx >= out_size) {
    return;
  }

  U total_val = Op::init;

  if(short(non_row_reductions) == 1) {
    uint in_idx = elem_to_loc(out_idx, shape, strides, ndim);
    const device T * in_row = in + in_idx;

    for(short i = simd_lane_id; i < short(reduction_size); i += 32) {
      total_val = op(static_cast<U>(in_row[i]), total_val);
    }
  }

  else if (short(non_row_reductions) >= 32) {
    for(short r = simd_lane_id; r < short(non_row_reductions); r+=32) {
      uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
      const device T * in_row = in + in_idx;

      for(short i = 0; i < short(reduction_size); i++) {
        total_val = op(static_cast<U>(in_row[i]), total_val);
      }
    }
  }

  else {
    const short n_reductions = short(reduction_size) * short(non_row_reductions);
    const short reductions_per_thread = (n_reductions + simd_size - 1) / simd_size;

    const short r_st = simd_lane_id / reductions_per_thread;
    const short r_ed = short(non_row_reductions);
    const short r_jump = simd_size / reductions_per_thread;

    const short i_st = simd_lane_id % reductions_per_thread;
    const short i_ed = short(reduction_size);
    const short i_jump = reductions_per_thread;

    for(short r = r_st; r < r_ed; r += r_jump) {
      uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
      const device T * in_row = in + in_idx;

      for(short i = i_st; i < i_ed; i += i_jump) {
        total_val = op(static_cast<U>(in_row[i]), total_val);
      }
    }
  }

  total_val = op.simd_reduce(total_val);

  if(simd_lane_id == 0) {
    out[out_idx] = total_val;
  }
}
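In its general branch, row_reduce_general_med splits a (non_row_reductions x reduction_size) tile across the 32 simd lanes so that reductions_per_thread lanes cooperate on each row. A Python model checking that the r_st/i_st/jump scheme covers every element exactly once, for one illustrative tile shape (the general coverage conditions are left to the kernel's dispatch constraints):

```python
# Python model of the general branch's work partitioning; the 4 x 20 tile
# shape is illustrative, and op would be applied at each visited element.
def med_partition_coverage(n_rows, row_size, simd_size=32):
    per_thread = -(-n_rows * row_size // simd_size)  # reductions_per_thread
    r_jump = simd_size // per_thread
    covered = [0] * (n_rows * row_size)
    for lane in range(simd_size):
        r_st, i_st = lane // per_thread, lane % per_thread
        for r in range(r_st, n_rows, r_jump):
            for i in range(i_st, row_size, per_thread):
                covered[r * row_size + i] += 1
    return covered

assert all(c == 1 for c in med_partition_coverage(4, 20))
```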
|
||||
|
||||
#define instantiate_row_reduce_small(name, itype, otype, op) \
|
||||
template[[host_name("row_reduce_general_small_" #name)]] \
|
||||
[[kernel]] void row_reduce_general_small<itype, otype, op>( \
|
||||
const device itype *in [[buffer(0)]], \
|
||||
device otype *out [[buffer(1)]], \
|
||||
const constant size_t& reduction_size [[buffer(2)]], \
|
||||
const constant size_t& out_size [[buffer(3)]], \
|
||||
const constant size_t& non_row_reductions [[buffer(4)]], \
|
||||
const constant int* shape [[buffer(5)]], \
|
||||
const constant size_t* strides [[buffer(6)]], \
|
||||
const constant int& ndim [[buffer(7)]], \
|
||||
uint lid [[thread_position_in_grid]]); \
|
||||
template[[host_name("row_reduce_general_med_" #name)]] \
|
||||
[[kernel]] void row_reduce_general_med<itype, otype, op>( \
|
||||
const device itype *in [[buffer(0)]], \
|
||||
device otype *out [[buffer(1)]], \
|
||||
const constant size_t& reduction_size [[buffer(2)]], \
|
||||
const constant size_t& out_size [[buffer(3)]], \
|
||||
const constant size_t& non_row_reductions [[buffer(4)]], \
|
||||
const constant int* shape [[buffer(5)]], \
|
||||
const constant size_t* strides [[buffer(6)]], \
|
||||
const constant int& ndim [[buffer(7)]], \
|
||||
uint tid [[threadgroup_position_in_grid]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_per_group [[dispatch_simdgroups_per_threadgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Large row reductions
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
||||
METAL_FUNC U per_thread_row_reduce(
|
||||
const device T* in,
|
||||
const constant size_t& reduction_size,
|
||||
const constant size_t& out_size,
|
||||
const constant int* shape,
|
||||
const constant size_t* strides,
|
||||
const constant int& ndim,
|
||||
uint lsize_x,
|
||||
uint lid_x,
|
||||
uint2 tid) {
|
||||
Op op;
|
||||
|
||||
// Each threadgroup handles 1 reduction
|
||||
// TODO: Specializing elem_to_loc would be slightly faster
|
||||
int idx = tid.y * out_size + tid.x;
|
||||
int extra_offset = elem_to_loc(idx, shape, strides, ndim);
|
||||
in += extra_offset + lid_x * N_READS;
|
||||
|
||||
// The reduction is accumulated here
|
||||
U total_val = Op::init;
|
||||
|
||||
// Loop over the reduction size within thread group
|
||||
int r = 0;
|
||||
for (; r < (int)ceildiv(reduction_size, N_READS * lsize_x) - 1; r++) {
|
||||
T vals[N_READS];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
vals[i] = in[i];
|
||||
}
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
total_val = op(static_cast<U>(vals[i]), total_val);
|
||||
}
|
||||
|
||||
in += lsize_x * N_READS;
|
||||
}
|
||||
|
||||
// Separate case for the last set as we close the reduction size
|
||||
size_t reduction_index = (lid_x + (size_t)lsize_x * r) * N_READS;
|
||||
if (reduction_index < reduction_size) {
|
||||
int max_reads = reduction_size - reduction_index;
|
||||
|
||||
T vals[N_READS];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
int idx = min(i, max_reads - 1);
|
||||
vals[i] = static_cast<U>(in[idx]);
|
||||
}
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
T val = i < max_reads ? vals[i] : Op::init;
|
||||
total_val = op(static_cast<U>(val), total_val);
|
||||
}
|
||||
}
|
||||
|
||||
return total_val;
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
||||
[[kernel]] void row_reduce_general(
|
||||
const device T *in [[buffer(0)]],
|
||||
device mlx_atomic<U> *out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& out_size [[buffer(3)]],
|
||||
const constant size_t& non_row_reductions [[buffer(4)]],
|
||||
const constant int* shape [[buffer(5)]],
|
||||
const constant size_t* strides [[buffer(6)]],
|
||||
const constant int& ndim [[buffer(7)]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint3 lsize [[threads_per_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_per_group [[simdgroups_per_threadgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
(void)non_row_reductions;
|
||||
|
||||
Op op;
|
||||
threadgroup U local_vals[simd_size];
|
||||
|
||||
U total_val = per_thread_row_reduce<T, U, Op, N_READS>(in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy);
|
||||
|
||||
total_val = op.simd_reduce(total_val);
|
||||
|
||||
// Prepare next level
|
||||
if (simd_lane_id == 0) {
|
||||
local_vals[simd_group_id] = total_val;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Reduction within thread group
|
||||
// Only needed if multiple simd groups
|
||||
if(reduction_size > simd_size) {
|
||||
total_val = lid.x < simd_per_group ? local_vals[lid.x] : op.init;
|
||||
total_val = op.simd_reduce(total_val);
|
||||
}
|
||||
// Update output
|
||||
if (lid.x == 0) {
|
||||
op.atomic_update(out, total_val, tid.x);
|
||||
}
|
||||
}

template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void row_reduce_general_no_atomics(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& out_size [[buffer(3)]],
    const constant size_t& non_row_reductions [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]],
    uint3 gsize [[threads_per_grid]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  (void)non_row_reductions;

  Op op;

  threadgroup U local_vals[simd_size];
  U total_val = per_thread_row_reduce<T, U, Op, N_READS>(
      in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy);

  // Reduction within simd group - simd_add isn't supported for int64 types
  for (uint16_t i = simd_size / 2; i > 0; i /= 2) {
    total_val = op(total_val, simd_shuffle_down(total_val, i));
  }

  // Prepare next level
  if (simd_lane_id == 0) {
    local_vals[simd_group_id] = total_val;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Reduction within thread group
  // Only needed if the thread group has multiple simd groups
  if (ceildiv(reduction_size, N_READS) > simd_size) {
    total_val = lid.x < simd_per_group ? local_vals[lid.x] : op.init;
    for (uint16_t i = simd_size / 2; i > 0; i /= 2) {
      total_val = op(total_val, simd_shuffle_down(total_val, i));
    }
  }
  // Write the row-reduce output with the first thread in each threadgroup
  if (lid.x == 0) {
    out[(ceildiv(gsize.y, lsize.y) * tid.x) + tid.y] = total_val;
  }
}
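// NOTE (illustrative, assuming simd_size == 32, not part of the original
// commit): the shuffle loop above is a tree reduction across the simdgroup.
// Unrolled, lane 0 ends up holding the combination of all 32 lanes:
//
//   total_val = op(total_val, simd_shuffle_down(total_val, 16)); // 32 -> 16
//   total_val = op(total_val, simd_shuffle_down(total_val, 8));  // 16 -> 8
//   total_val = op(total_val, simd_shuffle_down(total_val, 4));  //  8 -> 4
//   total_val = op(total_val, simd_shuffle_down(total_val, 2));  //  4 -> 2
//   total_val = op(total_val, simd_shuffle_down(total_val, 1));  //  2 -> 1
//
// It stands in for op.simd_reduce because Metal's simd_sum and friends do not
// support 64-bit integer operands.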

#define instantiate_row_reduce_general(name, itype, otype, op) \
  instantiate_row_reduce_small(name, itype, otype, op) \
  template [[host_name("row_reduce_general_" #name)]] \
  [[kernel]] void row_reduce_general<itype, otype, op>( \
      const device itype* in [[buffer(0)]], \
      device mlx_atomic<otype>* out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
      const constant size_t& out_size [[buffer(3)]], \
      const constant size_t& non_row_reductions [[buffer(4)]], \
      const constant int* shape [[buffer(5)]], \
      const constant size_t* strides [[buffer(6)]], \
      const constant int& ndim [[buffer(7)]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint3 lsize [[threads_per_threadgroup]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_per_group [[simdgroups_per_threadgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \
  instantiate_row_reduce_small(name, itype, otype, op) \
  template [[host_name("row_reduce_general_no_atomics_" #name)]] \
  [[kernel]] void row_reduce_general_no_atomics<itype, otype, op>( \
      const device itype* in [[buffer(0)]], \
      device otype* out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
      const constant size_t& out_size [[buffer(3)]], \
      const constant size_t& non_row_reductions [[buffer(4)]], \
      const constant int* shape [[buffer(5)]], \
      const constant size_t* strides [[buffer(6)]], \
      const constant int& ndim [[buffer(7)]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint3 lsize [[threads_per_threadgroup]], \
      uint3 gsize [[threads_per_grid]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_per_group [[simdgroups_per_threadgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

///////////////////////////////////////////////////////////////////////////////
// Instantiations
///////////////////////////////////////////////////////////////////////////////

#define instantiate_same_row_reduce_helper(name, tname, type, op) \
  instantiate_row_reduce_general(name##tname, type, type, op<type>)

#define instantiate_same_row_reduce_na_helper(name, tname, type, op) \
  instantiate_row_reduce_general_no_atomics(name##tname, type, type, op<type>)

instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_row_reduce_na_helper, instantiate_reduce_helper_64b)

instantiate_reduce_from_types(instantiate_row_reduce_general, and, bool, And)
instantiate_reduce_from_types(instantiate_row_reduce_general, or, bool, Or)

instantiate_row_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)

@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.

#pragma once
mlx/backend/metal/kernels/reduction/reduce_inst.h (new file)
@ -0,0 +1,71 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include <metal_atomic>
#include <metal_simdgroup>

#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"

#define instantiate_reduce_helper_floats(inst_f, name, op) \
  inst_f(name, float16, half, op) inst_f(name, float32, float, op) \
      inst_f(name, bfloat16, bfloat16_t, op)

#define instantiate_reduce_helper_uints(inst_f, name, op) \
  inst_f(name, uint8, uint8_t, op) inst_f(name, uint16, uint16_t, op) \
      inst_f(name, uint32, uint32_t, op)

#define instantiate_reduce_helper_ints(inst_f, name, op) \
  inst_f(name, int8, int8_t, op) inst_f(name, int16, int16_t, op) \
      inst_f(name, int32, int32_t, op)

#define instantiate_reduce_helper_64b(inst_f, name, op) \
  inst_f(name, int64, int64_t, op) inst_f(name, uint64, uint64_t, op)

#define instantiate_reduce_helper_types(inst_f, name, op) \
  instantiate_reduce_helper_floats(inst_f, name, op) \
      instantiate_reduce_helper_uints(inst_f, name, op) \
          instantiate_reduce_helper_ints(inst_f, name, op)

#define instantiate_reduce_ops(inst_f, type_f) \
  type_f(inst_f, sum, Sum) type_f(inst_f, prod, Prod) \
      type_f(inst_f, min_, Min) type_f(inst_f, max_, Max)

// Special case for bool reductions
#define instantiate_reduce_from_types_helper( \
    inst_f, name, tname, itype, otype, op) \
  inst_f(name##tname, itype, otype, op)

#define instantiate_reduce_from_types(inst_f, name, otype, op) \
  instantiate_reduce_from_types_helper(inst_f, name, bool_, bool, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, uint8, uint8_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, uint16, uint16_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, uint32, uint32_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, int8, int8_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, int16, int16_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, int32, int32_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, int64, int64_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, float16, half, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, float32, float, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, bfloat16, bfloat16_t, otype, op)
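// NOTE (illustrative expansion, not part of the original file): as an example
// of how these macros compose,
//
//   instantiate_reduce_from_types(instantiate_row_reduce_general, and, bool, And)
//
// invokes instantiate_reduce_from_types_helper once per input type; the bool_
// case expands to
//
//   instantiate_row_reduce_general(andbool_, bool, bool, And)
//
// which emits a kernel with host_name "row_reduce_general_andbool_", the exact
// string the C++ dispatch code assembles when it calls d.get_kernel().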
mlx/backend/metal/kernels/reduction/utils.h (new file)
@ -0,0 +1,14 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include <metal_atomic>
#include <metal_simdgroup>

#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/steel/utils.h"
#include "mlx/backend/metal/kernels/utils.h"

#include "mlx/backend/metal/kernels/reduction/ops.h"

static constant constexpr const uint8_t simd_size = 32;

@ -4,7 +4,7 @@

#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/indexing.h"
#include "mlx/backend/metal/kernels/reduce.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/utils.h"

using namespace metal;

@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.

#include <algorithm>
#include <cassert>

@ -130,15 +130,8 @@ void row_reduce_general_dispatch(
    const Stream& s) {
  Dtype out_dtype = out.dtype();
  bool is_out_64b_int = is_64b_int(out_dtype);
  auto kernel = (is_out_64b_int)
      ? d.get_kernel(
            "row_reduce_general_no_atomics_" + op_name + type_to_name(in))
      : d.get_kernel("row_reduce_general_" + op_name + type_to_name(in));

  compute_encoder->setComputePipelineState(kernel);

  // Prepare the arguments for the kernel
  int n_reads = REDUCE_N_READS;
  size_t reduction_size = plan.shape.back();
  auto shape = plan.shape;
  auto strides = plan.strides;

@ -160,32 +153,72 @@ void row_reduce_general_dispatch(
  }
  int ndim = shape.size();

  // Each thread group is responsible for 1 output
  NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
  thread_group_size =
      std::min((reduction_size + n_reads - 1) / n_reads, thread_group_size);
  // Determine dispatch kernel
  std::ostringstream kname;

  // Align thread group size with simd_size
  uint simd_size = kernel->threadExecutionWidth();
  thread_group_size =
      (thread_group_size + simd_size - 1) / simd_size * simd_size;
  assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup());
  bool is_small = non_row_reductions * reduction_size < 32;
  bool is_med = non_row_reductions * reduction_size <= 256;
  is_out_64b_int &= !is_small && !is_med;

  // Launch enough thread groups for each output
  size_t n_threads = out.size() * thread_group_size;
  MTL::Size grid_dims = MTL::Size(n_threads, non_row_reductions, 1);
  MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
  std::string small_desc = "_";
  if (is_small) {
    small_desc = "_small_";
  } else if (is_med) {
    small_desc = "_med_";
  }

  if (is_out_64b_int == false || non_row_reductions == 1) {
  small_desc = is_out_64b_int ? "_no_atomics_" : small_desc;

  kname << "row_reduce_general" << small_desc << op_name << type_to_name(in);

  auto kernel = d.get_kernel(kname.str());
  compute_encoder->setComputePipelineState(kernel);

  // Get dispatch grid dims
  MTL::Size grid_dims;
  MTL::Size group_dims;

  // Each thread handles one output
  if (is_small) {
    grid_dims = MTL::Size(out.size(), 1, 1);
    group_dims = MTL::Size(std::min(1024ul, out.size()), 1, 1);
  }
  // Each simdgroup handles one output
  else if (is_med) {
    grid_dims = MTL::Size(out.size() * 32, 1, 1);
    group_dims = MTL::Size(std::min(8ul, out.size()) * 32, 1, 1);
  }
  // Each threadgroup handles one output
  else {
    int n_reads = REDUCE_N_READS;
    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
    thread_group_size =
        std::min((reduction_size + n_reads - 1) / n_reads, thread_group_size);

    // Align thread group size with simd_size
    uint simd_size = kernel->threadExecutionWidth();
    thread_group_size =
        (thread_group_size + simd_size - 1) / simd_size * simd_size;
    assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup());

    // Launch enough thread groups for each output
    size_t n_threads = out.size() * thread_group_size;
    grid_dims = MTL::Size(n_threads, non_row_reductions, 1);
    group_dims = MTL::Size(thread_group_size, 1, 1);
  }
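// NOTE (illustrative, not part of the original commit): the kernel choice
// above is driven by the total work per output,
// non_row_reductions * reduction_size. For example, with
// non_row_reductions == 1:
//   reduction_size == 20   -> is_small (20 < 32):  one thread per output
//   reduction_size == 200  -> is_med (200 <= 256): one 32-wide simdgroup
//                                                  per output
//   reduction_size == 4096 -> neither:             one full threadgroup per
//                             output, sized to ceil(4096 / N_READS) threads
//                             and rounded up to a multiple of the simd width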

  // Dispatch kernel
  if (!is_out_64b_int || non_row_reductions == 1) {
    // Set the arguments for the kernel
    set_array_buffer(compute_encoder, in, 0);
    set_array_buffer(compute_encoder, out, 1);
    compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
    compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
    compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4);
    compute_encoder->setBytes(&non_row_reductions, sizeof(size_t), 4);
    compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5);
    compute_encoder->setBytes(
        strides.data(), strides.size() * sizeof(size_t), 5);
    compute_encoder->setBytes(&ndim, sizeof(int), 6);
        strides.data(), strides.size() * sizeof(size_t), 6);
    compute_encoder->setBytes(&ndim, sizeof(int), 7);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
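// NOTE (illustrative summary, matching the kernel signatures earlier in this
// diff): after this change the row_reduce_general argument layout is
//   buffer 0: in              buffer 4: non_row_reductions
//   buffer 1: out             buffer 5: shape
//   buffer 2: reduction_size  buffer 6: strides
//   buffer 3: out_size        buffer 7: ndim
// i.e. shape, strides, and ndim each move up one slot to make room for
// non_row_reductions at index 4.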

  } else {
@ -203,10 +236,11 @@ void row_reduce_general_dispatch(
    set_array_buffer(compute_encoder, intermediate, 1);
    compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
    compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
    compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4);
    compute_encoder->setBytes(&non_row_reductions, sizeof(size_t), 4);
    compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5);
    compute_encoder->setBytes(
        strides.data(), strides.size() * sizeof(size_t), 5);
    compute_encoder->setBytes(&ndim, sizeof(int), 6);
        strides.data(), strides.size() * sizeof(size_t), 6);
    compute_encoder->setBytes(&ndim, sizeof(int), 7);
    compute_encoder->dispatchThreads(grid_dims, group_dims);

    // Set up second dispatch
@ -230,24 +264,27 @@ void row_reduce_general_dispatch(
    set_array_buffer(compute_encoder, out, 1);
    compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
    compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
    compute_encoder->setBytes(&non_row_reductions, sizeof(size_t), 4);
    compute_encoder->setBytes(
        new_shape.data(), new_shape.size() * sizeof(int), 4);
        new_shape.data(), new_shape.size() * sizeof(int), 5);
    compute_encoder->setBytes(
        new_strides.data(), new_strides.size() * sizeof(size_t), 5);
    compute_encoder->setBytes(&ndim, sizeof(int), 6);
        new_strides.data(), new_strides.size() * sizeof(size_t), 6);
    compute_encoder->setBytes(&ndim, sizeof(int), 7);

    // Each thread group is responsible for 1 output
    thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
    int n_reads = REDUCE_N_READS;
    size_t thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
    thread_group_size =
        std::min((reduction_size + n_reads - 1) / n_reads, thread_group_size);

    // Align thread group size with simd_size
    uint simd_size = kernel->threadExecutionWidth();
    thread_group_size =
        (thread_group_size + simd_size - 1) / simd_size * simd_size;
    assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup());

    // Launch enough thread groups for each output
    n_threads = thread_group_size;
    size_t n_threads = thread_group_size;
    grid_dims = MTL::Size(n_threads, out.size(), 1);
    group_dims = MTL::Size(thread_group_size, 1, 1);
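// NOTE (illustrative, not part of the original commit): this is the second
// half of a two-pass scheme used when the output is a 64-bit integer type,
// for which Metal lacks the atomics the single-pass kernel relies on. Pass 1
// wrote one partial result per (output, non-row reduction) pair into
// `intermediate`; this pass re-runs the no-atomics row-reduce kernel over
// those partials, one threadgroup per output
// (grid = {thread_group_size, out.size(), 1}), so the final values can be
// written without atomic updates.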

@ -417,11 +454,12 @@ void strided_reduce_general_dispatch(
    set_array_buffer(compute_encoder, out, 1);
    compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
    compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
    compute_encoder->setBytes(&reduction_size, sizeof(size_t), 4);
    compute_encoder->setBytes(
        new_shape.data(), new_shape.size() * sizeof(int), 4);
        new_shape.data(), new_shape.size() * sizeof(int), 5);
    compute_encoder->setBytes(
        new_strides.data(), new_strides.size() * sizeof(size_t), 5);
    compute_encoder->setBytes(&ndim, sizeof(int), 6);
        new_strides.data(), new_strides.size() * sizeof(size_t), 6);
    compute_encoder->setBytes(&ndim, sizeof(int), 7);

    // Each thread group is responsible for 1 output
    size_t n_reads = REDUCE_N_READS;