Go to the source code of this file.
|
template<typename T , typename U , typename Op , int N_READS = REDUCE_N_READS> |
METAL_FUNC U | per_thread_all_reduce (const device T *in, const device size_t &in_size, uint gid, uint grid_size) |
|
template<typename T , typename U , typename Op , int N_READS = REDUCE_N_READS> |
void | all_reduce (const device T *in, device mlx_atomic< U > *out, const device size_t &in_size, uint gid, uint lid, uint grid_size, uint simd_per_group, uint simd_lane_id, uint simd_group_id) |
|
template<typename T , typename U , typename Op , int N_READS = REDUCE_N_READS> |
void | all_reduce_no_atomics (const device T *in, device U *out, const device size_t &in_size, uint gid, uint lid, uint grid_size, uint simd_per_group, uint simd_lane_id, uint simd_group_id, uint thread_group_id) |
|
◆ all_reduce()
template<typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
void all_reduce(
    const device T *in,
    device mlx_atomic<U> *out,
    const device size_t &in_size,
    uint gid,
    uint lid,
    uint grid_size,
    uint simd_per_group,
    uint simd_lane_id,
    uint simd_group_id)
◆ all_reduce_no_atomics()
template<typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
void all_reduce_no_atomics(
    const device T *in,
    device U *out,
    const device size_t &in_size,
    uint gid,
    uint lid,
    uint grid_size,
    uint simd_per_group,
    uint simd_lane_id,
    uint simd_group_id,
    uint thread_group_id)
◆ per_thread_all_reduce()
template<typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
METAL_FUNC U per_thread_all_reduce(
    const device T *in,
    const device size_t &in_size,
    uint gid,
    uint grid_size)