MLX
 
Loading...
Searching...
No Matches
reduce_col.h File Reference

Go to the source code of this file.

Functions

template<typename T, typename U, typename Op, typename IdxT, int NDIMS>
void col_reduce_small (const device T *in, device U *out, const constant size_t &reduction_size, const constant int64_t &reduction_stride, const constant int *shape, const constant int64_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant int64_t *reduce_strides, const constant int &reduce_ndim, const constant size_t &non_col_reductions, uint3 gid, uint3 gsize, uint3 lid, uint3 lsize)
 
template<typename T, typename U, typename Op, typename IdxT, int NDIMS>
void col_reduce_longcolumn (const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant int *shape, const constant int64_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant int64_t *reduce_strides, const constant int &reduce_ndim, const constant size_t &non_col_reductions, const constant size_t &out_size, uint3 gid, uint3 gsize, uint3 lid, uint3 lsize)
 
template<typename T, typename U, typename Op, typename IdxT, int NDIMS, int BM, int BN>
void col_reduce_looped (const device T *in, device U *out, const constant size_t &reduction_size, const constant int64_t &reduction_stride, const constant int *shape, const constant int64_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant int64_t *reduce_strides, const constant int &reduce_ndim, const constant size_t &non_col_reductions, uint3 gid, uint3 gsize, uint simd_lane_id, uint simd_group_id)
 Reduces columns using a simple looped approach; the steps are listed in the function documentation.
 
template<typename T, typename U, typename Op, typename IdxT, int NDIMS, int BM, int BN>
void col_reduce_2pass (const device T *in, device U *out, const constant size_t &reduction_size, const constant int64_t &reduction_stride, const constant int *shape, const constant int64_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant int64_t *reduce_strides, const constant int &reduce_ndim, const constant size_t &non_col_reductions, const constant size_t &out_size, uint3 gid, uint3 gsize, uint simd_lane_id, uint simd_group_id)
 

Function Documentation

◆ col_reduce_2pass()

template<typename T, typename U, typename Op, typename IdxT, int NDIMS, int BM, int BN>
void col_reduce_2pass ( const device T * in,
device U * out,
const constant size_t & reduction_size,
const constant int64_t & reduction_stride,
const constant int * shape,
const constant int64_t * strides,
const constant int & ndim,
const constant int * reduce_shape,
const constant int64_t * reduce_strides,
const constant int & reduce_ndim,
const constant size_t & non_col_reductions,
const constant size_t & out_size,
uint3 gid,
uint3 gsize,
uint simd_lane_id,
uint simd_group_id )

◆ col_reduce_longcolumn()

template<typename T, typename U, typename Op, typename IdxT, int NDIMS>
void col_reduce_longcolumn ( const device T * in,
device U * out,
const constant size_t & reduction_size,
const constant size_t & reduction_stride,
const constant int * shape,
const constant int64_t * strides,
const constant int & ndim,
const constant int * reduce_shape,
const constant int64_t * reduce_strides,
const constant int & reduce_ndim,
const constant size_t & non_col_reductions,
const constant size_t & out_size,
uint3 gid,
uint3 gsize,
uint3 lid,
uint3 lsize )

◆ col_reduce_looped()

template<typename T, typename U, typename Op, typename IdxT, int NDIMS, int BM, int BN>
void col_reduce_looped ( const device T * in,
device U * out,
const constant size_t & reduction_size,
const constant int64_t & reduction_stride,
const constant int * shape,
const constant int64_t * strides,
const constant int & ndim,
const constant int * reduce_shape,
const constant int64_t * reduce_strides,
const constant int & reduce_ndim,
const constant size_t & non_col_reductions,
uint3 gid,
uint3 gsize,
uint simd_lane_id,
uint simd_group_id )

Our approach is the following simple loop:

  1. Each thread keeps running totals for BN / n_simdgroups outputs.
  2. Load a BM x BN tile into registers and accumulate it into the running totals.
  3. Move ahead by BM steps until the column axis and the non-column reductions are exhausted.
  4. If BM == 32, transpose in shared memory (SM) and simd-reduce the running totals. Otherwise, write to shared memory and have BN threads accumulate the running totals with a loop.
  5. Write the results to the output.

◆ col_reduce_small()

template<typename T, typename U, typename Op, typename IdxT, int NDIMS>
void col_reduce_small ( const device T * in,
device U * out,
const constant size_t & reduction_size,
const constant int64_t & reduction_stride,
const constant int * shape,
const constant int64_t * strides,
const constant int & ndim,
const constant int * reduce_shape,
const constant int64_t * reduce_strides,
const constant int & reduce_ndim,
const constant size_t & non_col_reductions,
uint3 gid,
uint3 gsize,
uint3 lid,
uint3 lsize )