MLX
Loading...
Searching...
No Matches
Functions
reduce_col.h File Reference

Go to the source code of this file.

Functions

template<typename T , typename U , typename Op >
void col_reduce_small (const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant size_t &out_size, const constant int *shape, const constant size_t *strides, const constant int &ndim, const constant size_t &non_col_reductions, const constant int *non_col_shapes, const constant size_t *non_col_strides, const constant int &non_col_ndim, uint tid)
 
template<typename T , typename U , typename Op , int N_READS = REDUCE_N_READS>
METAL_FUNC U _contiguous_strided_reduce (const device T *in, threadgroup U *local_data, uint in_idx, uint reduction_size, uint reduction_stride, uint2 tid, uint2 lid, uint2 lsize)
 
template<typename T , typename U , typename Op , int N_READS = REDUCE_N_READS>
void col_reduce_general (const device T *in, device mlx_atomic< U > *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant size_t &out_size, const constant int *shape, const constant size_t *strides, const constant int &ndim, threadgroup U *local_data, uint3 tid, uint3 lid, uint3 lsize)
 
template<typename T , typename U , typename Op , int N_READS = REDUCE_N_READS>
void col_reduce_general_no_atomics (const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant size_t &out_size, const constant int *shape, const constant size_t *strides, const constant int &ndim, threadgroup U *local_data, uint3 tid, uint3 lid, uint3 gid, uint3 lsize, uint3 gsize)
 

Function Documentation

◆ _contiguous_strided_reduce()

template<typename T , typename U , typename Op , int N_READS = REDUCE_N_READS>
METAL_FUNC U _contiguous_strided_reduce ( const device T * in,
threadgroup U * local_data,
uint in_idx,
uint reduction_size,
uint reduction_stride,
uint2 tid,
uint2 lid,
uint2 lsize )

◆ col_reduce_general()

template<typename T , typename U , typename Op , int N_READS = REDUCE_N_READS>
void col_reduce_general ( const device T * in,
device mlx_atomic< U > * out,
const constant size_t & reduction_size,
const constant size_t & reduction_stride,
const constant size_t & out_size,
const constant int * shape,
const constant size_t * strides,
const constant int & ndim,
threadgroup U * local_data,
uint3 tid,
uint3 lid,
uint3 lsize )

◆ col_reduce_general_no_atomics()

template<typename T , typename U , typename Op , int N_READS = REDUCE_N_READS>
void col_reduce_general_no_atomics ( const device T * in,
device U * out,
const constant size_t & reduction_size,
const constant size_t & reduction_stride,
const constant size_t & out_size,
const constant int * shape,
const constant size_t * strides,
const constant int & ndim,
threadgroup U * local_data,
uint3 tid,
uint3 lid,
uint3 gid,
uint3 lsize,
uint3 gsize )

◆ col_reduce_small()

template<typename T , typename U , typename Op >
void col_reduce_small ( const device T * in,
device U * out,
const constant size_t & reduction_size,
const constant size_t & reduction_stride,
const constant size_t & out_size,
const constant int * shape,
const constant size_t * strides,
const constant int & ndim,
const constant size_t & non_col_reductions,
const constant int * non_col_shapes,
const constant size_t * non_col_strides,
const constant int & non_col_ndim,
uint tid )