MLX
Loading...
Searching...
No Matches
Functions
reduce_row.h File Reference

Go to the source code of this file.

Functions

template <typename T, typename U, typename Op>
void row_reduce_general_small(const device T* in, device U* out, const constant size_t& reduction_size, const constant size_t& out_size, const constant size_t& non_row_reductions, const constant int* shape, const constant size_t* strides, const constant int& ndim, uint lid)
 
template <typename T, typename U, typename Op>
void row_reduce_general_med(const device T* in, device U* out, const constant size_t& reduction_size, const constant size_t& out_size, const constant size_t& non_row_reductions, const constant int* shape, const constant size_t* strides, const constant int& ndim, uint tid, uint simd_lane_id, uint simd_per_group, uint simd_group_id)
 
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
METAL_FUNC U per_thread_row_reduce(const device T* in, const constant size_t& reduction_size, const constant size_t& out_size, const constant int* shape, const constant size_t* strides, const constant int& ndim, uint lsize_x, uint lid_x, uint2 tid)
 
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
void row_reduce_general(const device T* in, device mlx_atomic<U>* out, const constant size_t& reduction_size, const constant size_t& out_size, const constant size_t& non_row_reductions, const constant int* shape, const constant size_t* strides, const constant int& ndim, uint3 lid, uint3 lsize, uint3 tid, uint simd_lane_id, uint simd_per_group, uint simd_group_id)
 
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
void row_reduce_general_no_atomics(const device T* in, device U* out, const constant size_t& reduction_size, const constant size_t& out_size, const constant size_t& non_row_reductions, const constant int* shape, const constant size_t* strides, const constant int& ndim, uint3 lid, uint3 lsize, uint3 gsize, uint3 tid, uint simd_lane_id, uint simd_per_group, uint simd_group_id)
 

Function Documentation

◆ per_thread_row_reduce()

template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
METAL_FUNC U per_thread_row_reduce(
    const device T* in,
    const constant size_t& reduction_size,
    const constant size_t& out_size,
    const constant int* shape,
    const constant size_t* strides,
    const constant int& ndim,
    uint lsize_x,
    uint lid_x,
    uint2 tid)

◆ row_reduce_general()

template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
void row_reduce_general(
    const device T* in,
    device mlx_atomic<U>* out,
    const constant size_t& reduction_size,
    const constant size_t& out_size,
    const constant size_t& non_row_reductions,
    const constant int* shape,
    const constant size_t* strides,
    const constant int& ndim,
    uint3 lid,
    uint3 lsize,
    uint3 tid,
    uint simd_lane_id,
    uint simd_per_group,
    uint simd_group_id)

◆ row_reduce_general_med()

template <typename T, typename U, typename Op>
void row_reduce_general_med(
    const device T* in,
    device U* out,
    const constant size_t& reduction_size,
    const constant size_t& out_size,
    const constant size_t& non_row_reductions,
    const constant int* shape,
    const constant size_t* strides,
    const constant int& ndim,
    uint tid,
    uint simd_lane_id,
    uint simd_per_group,
    uint simd_group_id)

◆ row_reduce_general_no_atomics()

template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
void row_reduce_general_no_atomics(
    const device T* in,
    device U* out,
    const constant size_t& reduction_size,
    const constant size_t& out_size,
    const constant size_t& non_row_reductions,
    const constant int* shape,
    const constant size_t* strides,
    const constant int& ndim,
    uint3 lid,
    uint3 lsize,
    uint3 gsize,
    uint3 tid,
    uint simd_lane_id,
    uint simd_per_group,
    uint simd_group_id)

◆ row_reduce_general_small()

template <typename T, typename U, typename Op>
void row_reduce_general_small(
    const device T* in,
    device U* out,
    const constant size_t& reduction_size,
    const constant size_t& out_size,
    const constant size_t& non_row_reductions,
    const constant int* shape,
    const constant size_t* strides,
    const constant int& ndim,
    uint lid)