MLX
Loading...
Searching...
No Matches
Functions
reduce_all.h File Reference

Go to the source code of this file.

Functions

template<typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
METAL_FUNC U per_thread_all_reduce(const device T* in, const device size_t& in_size, uint gid, uint grid_size)
 
template<typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
void all_reduce(const device T* in, device mlx_atomic<U>* out, const device size_t& in_size, uint gid, uint lid, uint grid_size, uint simd_per_group, uint simd_lane_id, uint simd_group_id)
 
template<typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
void all_reduce_no_atomics(const device T* in, device U* out, const device size_t& in_size, uint gid, uint lid, uint grid_size, uint simd_per_group, uint simd_lane_id, uint simd_group_id, uint thread_group_id)
 

Function Documentation

◆ all_reduce()

template<typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
void all_reduce(
    const device T* in,
    device mlx_atomic<U>* out,
    const device size_t& in_size,
    uint gid,
    uint lid,
    uint grid_size,
    uint simd_per_group,
    uint simd_lane_id,
    uint simd_group_id)

◆ all_reduce_no_atomics()

template<typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
void all_reduce_no_atomics(
    const device T* in,
    device U* out,
    const device size_t& in_size,
    uint gid,
    uint lid,
    uint grid_size,
    uint simd_per_group,
    uint simd_lane_id,
    uint simd_group_id,
    uint thread_group_id)

◆ per_thread_all_reduce()

template<typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
METAL_FUNC U per_thread_all_reduce(
    const device T* in,
    const device size_t& in_size,
    uint gid,
    uint grid_size)