// Copyright © 2023 Apple Inc.

#include <metal_atomic>
#include <metal_simdgroup>

#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/reduce.h"
#include "mlx/backend/metal/kernels/utils.h"

using namespace metal;

// Width of a SIMD-group on Apple GPUs; it also sizes the threadgroup scratch
// arrays used for the second-level reductions below.
static constant uint8_t simd_size = 32;
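
// init_reduce fills the output with the reduction's identity element
// (e.g. 0 for Sum, 1 for Prod) so that the atomic kernels below can
// accumulate into it with read-modify-write updates.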
template <typename T, typename Op>
[[kernel]] void init_reduce(
    device T *out [[buffer(0)]],
    uint tid [[thread_position_in_grid]]) {
  out[tid] = Op::init;
}

#define instantiate_init_reduce(name, otype, op) \
  template [[host_name("i" #name)]] \
  [[kernel]] void init_reduce<otype, op>( \
      device otype *out [[buffer(0)]], \
      uint tid [[thread_position_in_grid]]);
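
// For illustration, instantiate_same_reduce(sum, float32, float, Sum) below
// invokes instantiate_init_reduce(sumfloat32, float, Sum<float>), which
// expands to roughly:
//
//   template [[host_name("isumfloat32")]]
//   [[kernel]] void init_reduce<float, Sum<float>>(
//       device float *out [[buffer(0)]],
//       uint tid [[thread_position_in_grid]]);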

///////////////////////////////////////////////////////////////////////////////
// All reduce
///////////////////////////////////////////////////////////////////////////////

template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
inline U per_thread_all_reduce(
    const device T *in,
    const device size_t& in_size,
    uint gid,
    uint grid_size) {
  Op op;
  U total_val = Op::init;

  if (gid * N_READS < in_size) {
    in += gid * N_READS;

    // Full blocks: every thread reads N_READS contiguous values, then the
    // whole grid strides forward by grid_size * N_READS.
    int r = 0;
    for (; r < (int)ceildiv(in_size, grid_size * N_READS) - 1; r++) {
      U vals[N_READS] = {op.init};

      for (int i = 0; i < N_READS; i++) {
        vals[i] = static_cast<U>(in[i]);
      }
      for (int i = 0; i < N_READS; i++) {
        total_val = op(vals[i], total_val);
      }

      in += grid_size * N_READS;
    }

    // Separate case for the last set as we close the reduction size
    size_t curr_idx = (gid + r * (size_t)grid_size) * N_READS;
    if (curr_idx < in_size) {
      int max_reads = in_size - curr_idx;
      T vals[N_READS];

      for (int i = 0, idx = 0; i < N_READS; i++, idx++) {
        idx = idx < max_reads ? idx : max_reads - 1;
        vals[i] = in[idx];
      }
      for (int i = 0; i < N_READS; i++) {
        U val = i < max_reads ? static_cast<U>(vals[i]) : Op::init;
        total_val = op(val, total_val);
      }
    }
  }

  return total_val;
}
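
// Worked example of the blocking above: with in_size = 4096, grid_size = 256,
// and N_READS = 4, each pass covers grid_size * N_READS = 1024 elements, so
// ceildiv(4096, 1024) = 4. Every thread runs 3 full passes of 4 contiguous
// reads each and then takes the tail branch, whose clamped reads would handle
// sizes that are not a multiple of N_READS.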

// NB: This kernel assumes threads_per_threadgroup is at most 1024. With a
// simd_size of 32 that leaves at most 1024 / 32 = 32 partial results, which
// fit in a single simd-group, so we are guaranteed to complete the reduction
// in two steps of simd-level reductions.
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void all_reduce(
    const device T *in [[buffer(0)]],
    device mlx_atomic<U> *out [[buffer(1)]],
    const device size_t& in_size [[buffer(2)]],
    uint gid [[thread_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint grid_size [[threads_per_grid]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  Op op;
  threadgroup U local_vals[simd_size];

  U total_val =
      per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);

  // Reduction within simd group
  total_val = op.simd_reduce(total_val);
  if (simd_lane_id == 0) {
    local_vals[simd_group_id] = total_val;
  }

  // Reduction within thread group
  threadgroup_barrier(mem_flags::mem_threadgroup);
  total_val = lid < simd_per_group ? local_vals[lid] : op.init;
  total_val = op.simd_reduce(total_val);

  // Reduction across threadgroups
  if (lid == 0) {
    op.atomic_update(out, total_val);
  }
}
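
// Note: all_reduce accumulates into *out via op.atomic_update, so the output
// must already hold the reduction identity when the kernel launches; that is
// what the init_reduce kernel above is for.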

template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void all_reduce_no_atomics(
    const device T *in [[buffer(0)]],
    device U *out [[buffer(1)]],
    const device size_t& in_size [[buffer(2)]],
    uint gid [[thread_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint grid_size [[threads_per_grid]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint thread_group_id [[threadgroup_position_in_grid]]) {
  Op op;
  threadgroup U local_vals[simd_size];

  U total_val =
      per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);

  // Reduction within simd group (simd_add isn't supported for uint64/int64
  // types), so use a shuffle-based tree reduction instead
  for (uint16_t lane_offset = simd_size / 2; lane_offset > 0; lane_offset /= 2) {
    total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
  }
  // Write simd group reduction results to local memory
  if (simd_lane_id == 0) {
    local_vals[simd_group_id] = total_val;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Reduction of simdgroup reduction results within threadgroup.
  total_val = lid < simd_per_group ? local_vals[lid] : op.init;
  for (uint16_t lane_offset = simd_size / 2; lane_offset > 0; lane_offset /= 2) {
    total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
  }

  // Each threadgroup writes its partial result; a follow-up pass reduces them
  if (lid == 0) {
    out[thread_group_id] = total_val;
  }
}
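
// The shuffle loop above is a tree reduction: for simd_size = 32 the offsets
// are 16, 8, 4, 2, 1 (log2(32) = 5 steps). With Op = Sum, after the first
// step lanes 0..15 hold pairwise sums, after the second lanes 0..7 hold sums
// of four, and so on until lane 0 holds the full simd-group result. Only
// lane 0's value is meaningful afterwards, which is why only simd_lane_id == 0
// (or lid == 0 for the final write) publishes it.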

#define instantiate_all_reduce(name, itype, otype, op) \
  template [[host_name("all_reduce_" #name)]] \
  [[kernel]] void all_reduce<itype, otype, op>( \
      const device itype *in [[buffer(0)]], \
      device mlx_atomic<otype> *out [[buffer(1)]], \
      const device size_t& in_size [[buffer(2)]], \
      uint gid [[thread_position_in_grid]], \
      uint lid [[thread_position_in_threadgroup]], \
      uint grid_size [[threads_per_grid]], \
      uint simd_per_group [[simdgroups_per_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \
  template [[host_name("all_reduce_no_atomics_" #name)]] \
  [[kernel]] void all_reduce_no_atomics<itype, otype, op>( \
      const device itype *in [[buffer(0)]], \
      device otype *out [[buffer(1)]], \
      const device size_t& in_size [[buffer(2)]], \
      uint gid [[thread_position_in_grid]], \
      uint lid [[thread_position_in_threadgroup]], \
      uint grid_size [[threads_per_grid]], \
      uint simd_per_group [[simdgroups_per_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
      uint thread_group_id [[threadgroup_position_in_grid]]);

///////////////////////////////////////////////////////////////////////////////
// Row reduce
///////////////////////////////////////////////////////////////////////////////

template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
inline U per_thread_row_reduce(
    const device T *in,
    const constant size_t& reduction_size,
    const constant size_t& out_size,
    const constant int* shape,
    const constant size_t* strides,
    const constant int& ndim,
    uint lsize_x,
    uint lid_x,
    uint2 tid) {
  Op op;

  // Each threadgroup handles 1 reduction
  // TODO: Specializing elem_to_loc would be slightly faster
  int idx = tid.y * out_size + tid.x;
  int extra_offset = elem_to_loc(idx, shape, strides, ndim);
  in += extra_offset + lid_x * N_READS;

  // The reduction is accumulated here
  U total_val = Op::init;

  // Loop over the reduction size within thread group
  int r = 0;
  for (; r < (int)ceildiv(reduction_size, N_READS * lsize_x) - 1; r++) {
    T vals[N_READS];
    for (int i = 0; i < N_READS; i++) {
      vals[i] = in[i];
    }
    for (int i = 0; i < N_READS; i++) {
      total_val = op(static_cast<U>(vals[i]), total_val);
    }

    in += lsize_x * N_READS;
  }

  // Separate case for the last set as we close the reduction size
  size_t reduction_index = (lid_x + (size_t)lsize_x * r) * N_READS;
  if (reduction_index < reduction_size) {
    int max_reads = reduction_size - reduction_index;

    T vals[N_READS];
    for (int i = 0; i < N_READS; i++) {
      int idx = min(i, max_reads - 1);
      vals[i] = in[idx];
    }
    for (int i = 0; i < N_READS; i++) {
      U val = i < max_reads ? static_cast<U>(vals[i]) : Op::init;
      total_val = op(val, total_val);
    }
  }

  return total_val;
}
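
// Access pattern sketch: with lsize_x = 128 and N_READS = 4, thread lid_x
// starts at element lid_x * 4 of its row, reads 4 contiguous values, and the
// whole threadgroup then strides forward by 128 * 4 = 512 elements, so each
// pass covers one contiguous 512-element block of the row.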

template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void row_reduce_general(
    const device T *in [[buffer(0)]],
    device mlx_atomic<U> *out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& out_size [[buffer(3)]],
    const constant int* shape [[buffer(4)]],
    const constant size_t* strides [[buffer(5)]],
    const constant int& ndim [[buffer(6)]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  Op op;
  threadgroup U local_vals[simd_size];

  U total_val = per_thread_row_reduce<T, U, Op, N_READS>(
      in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy);

  total_val = op.simd_reduce(total_val);

  // Prepare next level
  if (simd_lane_id == 0) {
    local_vals[simd_group_id] = total_val;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Reduction within thread group
  // Only needed if multiple simd groups
  if (reduction_size > simd_size) {
    total_val = lid.x < simd_per_group ? local_vals[lid.x] : op.init;
    total_val = op.simd_reduce(total_val);
  }
  // Update output
  if (lid.x == 0) {
    op.atomic_update(out, total_val, tid.x);
  }
}

template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void row_reduce_general_no_atomics(
    const device T *in [[buffer(0)]],
    device U *out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& out_size [[buffer(3)]],
    const constant int* shape [[buffer(4)]],
    const constant size_t* strides [[buffer(5)]],
    const constant int& ndim [[buffer(6)]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]],
    uint3 gsize [[threads_per_grid]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  Op op;
  threadgroup U local_vals[simd_size];

  U total_val = per_thread_row_reduce<T, U, Op, N_READS>(
      in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy);

  // Reduction within simd group - simd_add isn't supported for int64 types
  for (uint16_t i = simd_size / 2; i > 0; i /= 2) {
    total_val = op(total_val, simd_shuffle_down(total_val, i));
  }

  // Prepare next level
  if (simd_lane_id == 0) {
    local_vals[simd_group_id] = total_val;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Reduction within thread group
  // Only needed if thread group has multiple simd groups
  if (ceildiv(reduction_size, N_READS) > simd_size) {
    total_val = lid.x < simd_per_group ? local_vals[lid.x] : op.init;
    for (uint16_t i = simd_size / 2; i > 0; i /= 2) {
      total_val = op(total_val, simd_shuffle_down(total_val, i));
    }
  }
  // The first thread in each threadgroup writes that group's partial result
  if (lid.x == 0) {
    out[(ceildiv(gsize.y, lsize.y) * tid.x) + tid.y] = total_val;
  }
}
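
// Partial-result layout: ceildiv(gsize.y, lsize.y) is the number of
// threadgroups along y, so the partials for output element tid.x occupy a
// contiguous run of that length in `out`, ready for a follow-up pass to
// finish the reduction.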

#define instantiate_row_reduce_general(name, itype, otype, op) \
  template [[host_name("row_reduce_general_" #name)]] \
  [[kernel]] void row_reduce_general<itype, otype, op>( \
      const device itype *in [[buffer(0)]], \
      device mlx_atomic<otype> *out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
      const constant size_t& out_size [[buffer(3)]], \
      const constant int* shape [[buffer(4)]], \
      const constant size_t* strides [[buffer(5)]], \
      const constant int& ndim [[buffer(6)]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint3 lsize [[threads_per_threadgroup]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_per_group [[simdgroups_per_threadgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \
  template [[host_name("row_reduce_general_no_atomics_" #name)]] \
  [[kernel]] void row_reduce_general_no_atomics<itype, otype, op>( \
      const device itype *in [[buffer(0)]], \
      device otype *out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
      const constant size_t& out_size [[buffer(3)]], \
      const constant int* shape [[buffer(4)]], \
      const constant size_t* strides [[buffer(5)]], \
      const constant int& ndim [[buffer(6)]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint3 lsize [[threads_per_threadgroup]], \
      uint3 gsize [[threads_per_grid]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_per_group [[simdgroups_per_threadgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

///////////////////////////////////////////////////////////////////////////////
// Column reduce
///////////////////////////////////////////////////////////////////////////////

template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
inline U _contiguous_strided_reduce(
    const device T *in,
    threadgroup U *local_data,
    uint in_idx,
    uint reduction_size,
    uint reduction_stride,
    uint2 tid,
    uint2 lid,
    uint2 lsize) {
  Op op;
  U total_val = Op::init;

  // Each thread accumulates up to N_READS elements of its column, striding
  // down the reduction axis by reduction_stride between reads
  uint base_offset = (tid.y * lsize.y + lid.y) * N_READS;
  for (uint r = 0; r < N_READS && (base_offset + r) < reduction_size; r++) {
    uint offset = base_offset + r;
    total_val =
        op(total_val, static_cast<U>(in[in_idx + offset * reduction_stride]));
  }
  local_data[lsize.y * lid.x + lid.y] = total_val;
  threadgroup_barrier(mem_flags::mem_threadgroup);

  U val = Op::init;
  if (lid.y == 0) {
    // Perform reduction across columns in thread group
    for (uint i = 0; i < lsize.y; i++) {
      val = op(val, local_data[lsize.y * lid.x + i]);
    }
  }

  return val;
}
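
// Sketch of the cooperative layout, assuming a contiguous reduction: threads
// sharing a lid.x work on the same output column; each accumulates up to
// N_READS rows, parks its partial in local_data[lsize.y * lid.x + lid.y], and
// the lid.y == 0 thread folds the lsize.y partials together. Because
// consecutive lid.x values then map to consecutive addresses, the loads in
// the r-loop are coalesced across the threadgroup.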

template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void col_reduce_general(
    const device T *in [[buffer(0)]],
    device mlx_atomic<U> *out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant size_t& out_size [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    threadgroup U *local_data [[threadgroup(0)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]]) {
  auto out_idx = tid.x * lsize.x + lid.x;
  auto in_idx = elem_to_loc(out_idx + tid.z * out_size, shape, strides, ndim);

  Op op;
  if (out_idx < out_size) {
    U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
        in,
        local_data,
        in_idx,
        reduction_size,
        reduction_stride,
        tid.xy,
        lid.xy,
        lsize.xy);

    // Each threadgroup atomically folds its partial result into the output
    // element it worked on
    if (lid.y == 0) {
      op.atomic_update(out, val, out_idx);
    }
  }
}
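
// local_data is bound by the host via [[threadgroup(0)]]; the indexing in
// _contiguous_strided_reduce touches lsize.y * lid.x + lid.y, so the dispatch
// must provide at least lsize.x * lsize.y * sizeof(U) bytes of threadgroup
// memory.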

template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void col_reduce_general_no_atomics(
    const device T *in [[buffer(0)]],
    device U *out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant size_t& out_size [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    threadgroup U *local_data [[threadgroup(0)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 gid [[thread_position_in_grid]],
    uint3 lsize [[threads_per_threadgroup]],
    uint3 gsize [[threads_per_grid]]) {
  auto out_idx = tid.x * lsize.x + lid.x;
  auto in_idx = elem_to_loc(out_idx + tid.z * out_size, shape, strides, ndim);

  if (out_idx < out_size) {
    U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
        in,
        local_data,
        in_idx,
        reduction_size,
        reduction_stride,
        tid.xy,
        lid.xy,
        lsize.xy);

    // Each threadgroup writes its partial result so that all partials for a
    // given output element land contiguously
    if (lid.y == 0) {
      uint tgsize_y = ceildiv(gsize.y, lsize.y);
      uint tgsize_z = ceildiv(gsize.z, lsize.z);
      out[tgsize_y * tgsize_z * gid.x + tgsize_y * tid.z + tid.y] = val;
    }
  }
}
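
// Partial-result layout written above: with tgsize_y = ceildiv(gsize.y,
// lsize.y) threadgroups along the reduction axis and tgsize_z along z, the
// buffer is effectively indexed as out[out_idx][tid.z][tid.y], i.e. shape
// (out_size, tgsize_z, tgsize_y), so a follow-up pass can reduce the trailing
// tgsize_z * tgsize_y partials per output element.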

#define instantiate_col_reduce_general(name, itype, otype, op) \
  template [[host_name("col_reduce_general_" #name)]] \
  [[kernel]] void col_reduce_general<itype, otype, op>( \
      const device itype *in [[buffer(0)]], \
      device mlx_atomic<otype> *out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
      const constant size_t& reduction_stride [[buffer(3)]], \
      const constant size_t& out_size [[buffer(4)]], \
      const constant int* shape [[buffer(5)]], \
      const constant size_t* strides [[buffer(6)]], \
      const constant int& ndim [[buffer(7)]], \
      threadgroup otype *local_data [[threadgroup(0)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint3 lsize [[threads_per_threadgroup]]);

#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \
  template [[host_name("col_reduce_general_no_atomics_" #name)]] \
  [[kernel]] void col_reduce_general_no_atomics<itype, otype, op>( \
      const device itype *in [[buffer(0)]], \
      device otype *out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
      const constant size_t& reduction_stride [[buffer(3)]], \
      const constant size_t& out_size [[buffer(4)]], \
      const constant int* shape [[buffer(5)]], \
      const constant size_t* strides [[buffer(6)]], \
      const constant int& ndim [[buffer(7)]], \
      threadgroup otype *local_data [[threadgroup(0)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint3 gid [[thread_position_in_grid]], \
      uint3 lsize [[threads_per_threadgroup]], \
      uint3 gsize [[threads_per_grid]]);

///////////////////////////////////////////////////////////////////////////////
// Instantiations
///////////////////////////////////////////////////////////////////////////////

#define instantiate_reduce(name, itype, otype, op) \
  instantiate_all_reduce(name, itype, otype, op) \
  instantiate_row_reduce_general(name, itype, otype, op) \
  instantiate_col_reduce_general(name, itype, otype, op)

#define instantiate_reduce_no_atomics(name, itype, otype, op) \
  instantiate_all_reduce_no_atomics(name, itype, otype, op) \
  instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \
  instantiate_col_reduce_general_no_atomics(name, itype, otype, op)
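
// The *_no_atomics family exists because Metal does not provide general
// 64-bit atomics; as the instantiations below show, it is used for the
// int64/uint64 reductions, with the final combine done in a follow-up pass
// instead of an atomic update.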

#define instantiate_same_reduce_no_atomics(name, tname, type, op) \
  instantiate_init_reduce(name ##tname, type, op<type>) \
  instantiate_reduce_no_atomics(name ##tname, type, type, op<type>)

#define instantiate_same_reduce(name, tname, type, op) \
  instantiate_init_reduce(name ##tname, type, op<type>) \
  instantiate_reduce(name ##tname, type, type, op<type>)

#define instantiate_reduce_from_types_helper(name, tname, itype, otype, op) \
  instantiate_reduce(name ##tname, itype, otype, op)

#define instantiate_reduce_from_types(name, otype, op) \
  instantiate_reduce_from_types_helper(name, bool_, bool, otype, op) \
  instantiate_reduce_from_types_helper(name, uint8, uint8_t, otype, op) \
  instantiate_reduce_from_types_helper(name, uint16, uint16_t, otype, op) \
  instantiate_reduce_from_types_helper(name, uint32, uint32_t, otype, op) \
  instantiate_reduce_from_types_helper(name, int8, int8_t, otype, op) \
  instantiate_reduce_from_types_helper(name, int16, int16_t, otype, op) \
  instantiate_reduce_from_types_helper(name, int32, int32_t, otype, op) \
  instantiate_reduce_from_types_helper(name, int64, int64_t, otype, op) \
  instantiate_reduce_from_types_helper(name, float16, half, otype, op) \
  instantiate_reduce_from_types_helper(name, float32, float, otype, op) \
  instantiate_reduce_from_types_helper(name, bfloat16, bfloat16_t, otype, op)
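
// Naming scheme: instantiate_same_reduce(sum, float32, float, Sum) registers
// kernels under host names such as "all_reduce_sumfloat32",
// "row_reduce_general_sumfloat32", and "col_reduce_general_sumfloat32", plus
// the init kernel "isumfloat32"; the host looks kernels up by these strings.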

// special case bool with larger output type
instantiate_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)

instantiate_same_reduce(sum, uint8, uint8_t, Sum)
instantiate_same_reduce(sum, uint16, uint16_t, Sum)
instantiate_same_reduce(sum, uint32, uint32_t, Sum)
instantiate_same_reduce(sum, int8, int8_t, Sum)
instantiate_same_reduce(sum, int16, int16_t, Sum)
instantiate_same_reduce(sum, int32, int32_t, Sum)
instantiate_same_reduce(sum, float16, half, Sum)
instantiate_same_reduce(sum, float32, float, Sum)

instantiate_same_reduce_no_atomics(sum, int64, int64_t, Sum)
instantiate_same_reduce_no_atomics(sum, uint64, uint64_t, Sum)

instantiate_same_reduce(prod, uint8, uint8_t, Prod)
instantiate_same_reduce(prod, uint16, uint16_t, Prod)
instantiate_same_reduce(prod, uint32, uint32_t, Prod)
instantiate_same_reduce(prod, int8, int8_t, Prod)
instantiate_same_reduce(prod, int16, int16_t, Prod)
instantiate_same_reduce(prod, int32, int32_t, Prod)
instantiate_same_reduce(prod, float16, half, Prod)
instantiate_same_reduce(prod, float32, float, Prod)

instantiate_same_reduce_no_atomics(prod, int64, int64_t, Prod)
instantiate_same_reduce_no_atomics(prod, uint64, uint64_t, Prod)

instantiate_same_reduce(sum, bfloat16, bfloat16_t, Sum)
instantiate_same_reduce(prod, bfloat16, bfloat16_t, Prod)

instantiate_init_reduce(andbool_, bool, And)
instantiate_reduce_from_types(and, bool, And)

instantiate_init_reduce(orbool_, bool, Or)
instantiate_reduce_from_types(or, bool, Or)

// Compiler segfaulted with the names "min" or "max" ...
instantiate_same_reduce(min_, uint8, uint8_t, Min)
instantiate_same_reduce(min_, uint16, uint16_t, Min)
instantiate_same_reduce(min_, uint32, uint32_t, Min)
instantiate_same_reduce(min_, int8, int8_t, Min)
instantiate_same_reduce(min_, int16, int16_t, Min)
instantiate_same_reduce(min_, int32, int32_t, Min)
instantiate_same_reduce(min_, float16, half, Min)
instantiate_same_reduce(min_, float32, float, Min)

instantiate_same_reduce_no_atomics(min_, int64, int64_t, Min)
instantiate_same_reduce_no_atomics(min_, uint64, uint64_t, Min)

instantiate_same_reduce(max_, uint8, uint8_t, Max)
instantiate_same_reduce(max_, uint16, uint16_t, Max)
instantiate_same_reduce(max_, uint32, uint32_t, Max)
instantiate_same_reduce(max_, int8, int8_t, Max)
instantiate_same_reduce(max_, int16, int16_t, Max)
instantiate_same_reduce(max_, int32, int32_t, Max)
instantiate_same_reduce(max_, float16, half, Max)
instantiate_same_reduce(max_, float32, float, Max)

instantiate_same_reduce_no_atomics(max_, int64, int64_t, Max)
instantiate_same_reduce_no_atomics(max_, uint64, uint64_t, Max)

instantiate_same_reduce(min_, bfloat16, bfloat16_t, Min)
instantiate_same_reduce(max_, bfloat16, bfloat16_t, Max)