MLX
Loading...
Searching...
No Matches
reduce_col.h File Reference

Go to the source code of this file.

Functions

template<typename T , typename U , typename Op , int NDIMS, int N_READS = REDUCE_N_READS>
void col_reduce_small (const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant int *shape, const constant size_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &reduce_ndim, const constant size_t &non_col_reductions, uint3 gid, uint3 gsize, uint simd_lane_id, uint simd_group_id, uint3 tid, uint3 tsize)
 
template<typename T , typename U , typename Op , int NDIMS, int BM, int BN>
void col_reduce_looped (const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant int *shape, const constant size_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &reduce_ndim, const constant size_t &non_col_reductions, uint3 gid, uint3 gsize, uint simd_lane_id, uint simd_group_id)
 Reduces columns with a simple looped approach (the steps are listed under Function Documentation).
 

Function Documentation

◆ col_reduce_looped()

template<typename T , typename U , typename Op , int NDIMS, int BM, int BN>
void col_reduce_looped ( const device T * in,
device U * out,
const constant size_t & reduction_size,
const constant size_t & reduction_stride,
const constant int * shape,
const constant size_t * strides,
const constant int & ndim,
const constant int * reduce_shape,
const constant size_t * reduce_strides,
const constant int & reduce_ndim,
const constant size_t & non_col_reductions,
uint3 gid,
uint3 gsize,
uint simd_lane_id,
uint simd_group_id )

The kernel uses the following simple looped approach:

  1. Each thread keeps running totals for BN / n_simdgroups outputs.
  2. Load a BM × BN tile into registers and accumulate it into the running totals.
  3. Advance BM steps at a time until both the column axis and the non-column reductions are exhausted.
  4. If BM == 32, transpose in SM and SIMD-reduce the running totals. Otherwise, write them to shared memory and have BN threads accumulate the running totals in a loop.
  5. Write the final totals to the output.

◆ col_reduce_small()

template<typename T , typename U , typename Op , int NDIMS, int N_READS = REDUCE_N_READS>
void col_reduce_small ( const device T * in,
device U * out,
const constant size_t & reduction_size,
const constant size_t & reduction_stride,
const constant int * shape,
const constant size_t * strides,
const constant int & ndim,
const constant int * reduce_shape,
const constant size_t * reduce_strides,
const constant int & reduce_ndim,
const constant size_t & non_col_reductions,
uint3 gid,
uint3 gsize,
uint simd_lane_id,
uint simd_group_id,
uint3 tid,
uint3 tsize )