MLX
 
Loading...
Searching...
No Matches
scan.h File Reference

Go to the source code of this file.

Classes

struct  CumSum< U >
 
struct  CumProd< U >
 
struct  CumProd< bool >
 
struct  CumMax< U >
 
struct  CumMin< U >
 

Macros

#define DEFINE_SIMD_SCAN()
 
#define DEFINE_SIMD_EXCLUSIVE_SCAN()
 

Functions

template<typename T, typename U, int N_READS, bool reverse>
void load_unsafe (U values[N_READS], const device T *input)
 
template<typename T, typename U, int N_READS, bool reverse>
void load_safe (U values[N_READS], const device T *input, int start, int total, U init)
 
template<typename U, int N_READS, bool reverse>
void write_unsafe (U values[N_READS], device U *out)
 
template<typename U, int N_READS, bool reverse>
void write_safe (U values[N_READS], device U *out, int start, int total)
 
template<typename T, typename U, typename Op, int N_READS, bool inclusive, bool reverse>
void contiguous_scan (const device T *in, device U *out, const constant size_t &axis_size, uint3 gid, uint3 gsize, uint3 lid, uint3 lsize, uint simd_lane_id, uint simd_group_id)
 
template<typename T, typename U, typename Op, int N_READS, bool inclusive, bool reverse>
void strided_scan (const device T *in, device U *out, const constant size_t &axis_size, const constant size_t &stride, const constant size_t &stride_blocks, uint3 gid, uint3 gsize, uint3 lid, uint simd_lane_id, uint simd_group_id)
 

Macro Definition Documentation

◆ DEFINE_SIMD_EXCLUSIVE_SCAN

#define DEFINE_SIMD_EXCLUSIVE_SCAN ( )
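Value (`init`, `simd_scan`, and `simd_exclusive_scan_impl` are not defined by this macro; they must be in scope wherever the macro is expanded):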
template <typename T, metal::enable_if_t<sizeof(T) < 8, bool> = true> \
T simd_exclusive_scan(T val) { \
return simd_exclusive_scan_impl(val); \
} \
\
template <typename T, metal::enable_if_t<sizeof(T) == 8, bool> = true> \
T simd_exclusive_scan(T val) { \
val = simd_scan(val); \
return simd_shuffle_and_fill_up(val, init, 1); \
}
Referenced symbol: `uint64_t simd_shuffle_and_fill_up(uint64_t data, uint64_t filling, uint16_t delta)`, defined in utils.h:372.

◆ DEFINE_SIMD_SCAN

#define DEFINE_SIMD_SCAN ( )
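Value (`init`, `operator()`, and `simd_scan_impl` are not defined by this macro; because `operator()` is called unqualified, the macro must be expanded inside a class that provides these members):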
template <typename T, metal::enable_if_t<sizeof(T) < 8, bool> = true> \
T simd_scan(T val) { \
return simd_scan_impl(val); \
} \
\
template <typename T, metal::enable_if_t<sizeof(T) == 8, bool> = true> \
T simd_scan(T val) { \
for (int i = 1; i <= 16; i *= 2) { \
val = operator()(val, simd_shuffle_and_fill_up(val, init, i)); \
} \
return val; \
}

Function Documentation

◆ contiguous_scan()

template<typename T, typename U, typename Op, int N_READS, bool inclusive, bool reverse>
void contiguous_scan ( const device T * in,
device U * out,
const constant size_t & axis_size,
uint3 gid,
uint3 gsize,
uint3 lid,
uint3 lsize,
uint simd_lane_id,
uint simd_group_id )

◆ load_safe()

template<typename T, typename U, int N_READS, bool reverse>
void load_safe ( U values[N_READS],
const device T * input,
int start,
int total,
U init )
inline

◆ load_unsafe()

template<typename T, typename U, int N_READS, bool reverse>
void load_unsafe ( U values[N_READS],
const device T * input )
inline

◆ strided_scan()

template<typename T, typename U, typename Op, int N_READS, bool inclusive, bool reverse>
void strided_scan ( const device T * in,
device U * out,
const constant size_t & axis_size,
const constant size_t & stride,
const constant size_t & stride_blocks,
uint3 gid,
uint3 gsize,
uint3 lid,
uint simd_lane_id,
uint simd_group_id )

◆ write_safe()

template<typename U, int N_READS, bool reverse>
void write_safe ( U values[N_READS],
device U * out,
int start,
int total )
inline

◆ write_unsafe()

template<typename U, int N_READS, bool reverse>
void write_unsafe ( U values[N_READS],
device U * out )
inline