MLX
Loading...
Searching...
No Matches
reduce_row.h File Reference

Go to the source code of this file.

Functions

template<typename T , typename U , typename Op , int N_READS = REDUCE_N_READS, int N_WRITES = REDUCE_N_WRITES>
METAL_FUNC void per_thread_row_reduce (thread U totals[N_WRITES], const device T *inputs[N_WRITES], int blocks, int extra, uint lsize_x, uint lid_x)
 The threadgroup collaboratively reduces across the rows with bounds checking.
 
template<typename T , typename U , typename Op , int N_READS = REDUCE_N_READS, int N_WRITES = REDUCE_N_WRITES>
METAL_FUNC void per_thread_row_reduce (thread U totals[N_WRITES], const device T *in, const constant size_t &reduction_size, int blocks, int extra, uint lsize_x, uint lid_x)
 Overload for consecutive rows in a contiguous array.
 
template<typename T , typename U , typename Op , int N_READS = REDUCE_N_READS, int N_WRITES = REDUCE_N_WRITES>
METAL_FUNC void per_thread_row_reduce (thread U totals[N_WRITES], const device T *in, const size_t row_idx, int blocks, int extra, const constant int *shape, const constant size_t *strides, const constant int &ndim, uint lsize_x, uint lid_x)
 Overload for consecutive rows in an arbitrarily ordered array.
 
template<typename T , typename U , typename Op , int N_READS = REDUCE_N_READS, int N_WRITES = REDUCE_N_WRITES>
METAL_FUNC void threadgroup_reduce (thread U totals[N_WRITES], threadgroup U *shared_vals, uint3 lid, uint simd_lane_id, uint simd_per_group, uint simd_group_id)
 Reduce within the threadgroup.
 
template<typename T , typename U , typename Op , int N_READS = REDUCE_N_READS>
METAL_FUNC void thread_reduce (thread U &total, const device T *row, int blocks, int extra)
 
template<typename T , typename U , typename Op , int NDIMS, int N_READS = REDUCE_N_READS>
void row_reduce_small (const device T *in, device U *out, const constant size_t &row_size, const constant size_t &non_row_reductions, const constant int *shape, const constant size_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &reduce_ndim, uint simd_lane_id, uint3 gid, uint3 gsize, uint3 tid, uint3 tsize)
 
template<typename T , typename U , typename Op , int N_READS = REDUCE_N_READS, int N_WRITES = REDUCE_N_WRITES>
void row_reduce_simple (const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &out_size, uint3 gid, uint3 gsize, uint3 lid, uint3 lsize, uint simd_lane_id, uint simd_per_group, uint simd_group_id)
 
template<typename T , typename U , typename Op , int NDIMS, int N_READS = REDUCE_N_READS>
void row_reduce_looped (const device T *in, device U *out, const constant size_t &row_size, const constant size_t &non_row_reductions, const constant int *shape, const constant size_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &reduce_ndim, uint3 gid, uint3 gsize, uint3 lid, uint3 lsize, uint simd_lane_id, uint simd_per_group, uint simd_group_id)
 

Function Documentation

◆ per_thread_row_reduce() [1/3]

template<typename T , typename U , typename Op , int N_READS = REDUCE_N_READS, int N_WRITES = REDUCE_N_WRITES>
METAL_FUNC void per_thread_row_reduce ( thread U totals[N_WRITES],
const device T * in,
const constant size_t & reduction_size,
int blocks,
int extra,
uint lsize_x,
uint lid_x )

Overload for consecutive rows in a contiguous array.

◆ per_thread_row_reduce() [2/3]

template<typename T , typename U , typename Op , int N_READS = REDUCE_N_READS, int N_WRITES = REDUCE_N_WRITES>
METAL_FUNC void per_thread_row_reduce ( thread U totals[N_WRITES],
const device T * in,
const size_t row_idx,
int blocks,
int extra,
const constant int * shape,
const constant size_t * strides,
const constant int & ndim,
uint lsize_x,
uint lid_x )

Overload for consecutive rows in an arbitrarily ordered array.

◆ per_thread_row_reduce() [3/3]

template<typename T , typename U , typename Op , int N_READS = REDUCE_N_READS, int N_WRITES = REDUCE_N_WRITES>
METAL_FUNC void per_thread_row_reduce ( thread U totals[N_WRITES],
const device T * inputs[N_WRITES],
int blocks,
int extra,
uint lsize_x,
uint lid_x )

The threadgroup collaboratively reduces across the rows with bounds checking.

In the end, each thread holds a part of the reduction.

◆ row_reduce_looped()

template<typename T , typename U , typename Op , int NDIMS, int N_READS = REDUCE_N_READS>
void row_reduce_looped ( const device T * in,
device U * out,
const constant size_t & row_size,
const constant size_t & non_row_reductions,
const constant int * shape,
const constant size_t * strides,
const constant int & ndim,
const constant int * reduce_shape,
const constant size_t * reduce_strides,
const constant int & reduce_ndim,
uint3 gid,
uint3 gsize,
uint3 lid,
uint3 lsize,
uint simd_lane_id,
uint simd_per_group,
uint simd_group_id )

◆ row_reduce_simple()

template<typename T , typename U , typename Op , int N_READS = REDUCE_N_READS, int N_WRITES = REDUCE_N_WRITES>
void row_reduce_simple ( const device T * in,
device U * out,
const constant size_t & reduction_size,
const constant size_t & out_size,
uint3 gid,
uint3 gsize,
uint3 lid,
uint3 lsize,
uint simd_lane_id,
uint simd_per_group,
uint simd_group_id )

◆ row_reduce_small()

template<typename T , typename U , typename Op , int NDIMS, int N_READS = REDUCE_N_READS>
void row_reduce_small ( const device T * in,
device U * out,
const constant size_t & row_size,
const constant size_t & non_row_reductions,
const constant int * shape,
const constant size_t * strides,
const constant int & ndim,
const constant int * reduce_shape,
const constant size_t * reduce_strides,
const constant int & reduce_ndim,
uint simd_lane_id,
uint3 gid,
uint3 gsize,
uint3 tid,
uint3 tsize )

◆ thread_reduce()

template<typename T , typename U , typename Op , int N_READS = REDUCE_N_READS>
METAL_FUNC void thread_reduce ( thread U & total,
const device T * row,
int blocks,
int extra )

◆ threadgroup_reduce()

template<typename T , typename U , typename Op , int N_READS = REDUCE_N_READS, int N_WRITES = REDUCE_N_WRITES>
METAL_FUNC void threadgroup_reduce ( thread U totals[N_WRITES],
threadgroup U * shared_vals,
uint3 lid,
uint simd_lane_id,
uint simd_per_group,
uint simd_group_id )

Reduce within the threadgroup.