MLX
Public Member Functions | Public Attributes | Static Public Attributes | List of all members
QuantizedBlockLoader< T, BROWS, BCOLS, dst_ld, reduction_dim, tgp_size, group_size, bits > Struct Template Reference

#include <quantized.h>

Public Member Functions

 QuantizedBlockLoader (const device uint32_t *src_, const device T *scales_, const device T *biases_, const int src_ld_, threadgroup T *dst_, ushort simd_group_id, ushort simd_lane_id)
 
void load_unsafe () const
 
void load_safe (short2 src_tile_dim) const
 
void next ()
 

Public Attributes

const int src_ld
 
const int tile_stride
 
short group_step_cnt
 
const int group_stride
 
const short thread_idx
 
const short bi
 
const short bj
 
threadgroup T * dst
 
const device uint32_t * src
 
const device T * scales
 
const device T * biases
 

Static Public Attributes

static constant constexpr const short pack_factor = 32 / bits
 
static constant constexpr const short BCOLS_PACKED = BCOLS / pack_factor
 
static constant constexpr const short n_reads
 
static constant constexpr const short group_steps = group_size / BCOLS
 

Constructor & Destructor Documentation

◆ QuantizedBlockLoader()

template<typename T , short BROWS, short BCOLS, short dst_ld, short reduction_dim, short tgp_size, short group_size, short bits>
QuantizedBlockLoader< T, BROWS, BCOLS, dst_ld, reduction_dim, tgp_size, group_size, bits >::QuantizedBlockLoader ( const device uint32_t * src_,
const device T * scales_,
const device T * biases_,
const int src_ld_,
threadgroup T * dst_,
ushort simd_group_id,
ushort simd_lane_id )
inline

Member Function Documentation

◆ load_safe()

template<typename T , short BROWS, short BCOLS, short dst_ld, short reduction_dim, short tgp_size, short group_size, short bits>
void QuantizedBlockLoader< T, BROWS, BCOLS, dst_ld, reduction_dim, tgp_size, group_size, bits >::load_safe ( short2 src_tile_dim) const
inline

◆ load_unsafe()

template<typename T , short BROWS, short BCOLS, short dst_ld, short reduction_dim, short tgp_size, short group_size, short bits>
void QuantizedBlockLoader< T, BROWS, BCOLS, dst_ld, reduction_dim, tgp_size, group_size, bits >::load_unsafe ( ) const
inline

◆ next()

template<typename T , short BROWS, short BCOLS, short dst_ld, short reduction_dim, short tgp_size, short group_size, short bits>
void QuantizedBlockLoader< T, BROWS, BCOLS, dst_ld, reduction_dim, tgp_size, group_size, bits >::next ( )
inline

Member Data Documentation

◆ BCOLS_PACKED

template<typename T , short BROWS, short BCOLS, short dst_ld, short reduction_dim, short tgp_size, short group_size, short bits>
constant constexpr const short QuantizedBlockLoader< T, BROWS, BCOLS, dst_ld, reduction_dim, tgp_size, group_size, bits >::BCOLS_PACKED = BCOLS / pack_factor
static constexpr

◆ bi

template<typename T , short BROWS, short BCOLS, short dst_ld, short reduction_dim, short tgp_size, short group_size, short bits>
const short QuantizedBlockLoader< T, BROWS, BCOLS, dst_ld, reduction_dim, tgp_size, group_size, bits >::bi

◆ biases

template<typename T , short BROWS, short BCOLS, short dst_ld, short reduction_dim, short tgp_size, short group_size, short bits>
const device T* QuantizedBlockLoader< T, BROWS, BCOLS, dst_ld, reduction_dim, tgp_size, group_size, bits >::biases

◆ bj

template<typename T , short BROWS, short BCOLS, short dst_ld, short reduction_dim, short tgp_size, short group_size, short bits>
const short QuantizedBlockLoader< T, BROWS, BCOLS, dst_ld, reduction_dim, tgp_size, group_size, bits >::bj

◆ dst

template<typename T , short BROWS, short BCOLS, short dst_ld, short reduction_dim, short tgp_size, short group_size, short bits>
threadgroup T* QuantizedBlockLoader< T, BROWS, BCOLS, dst_ld, reduction_dim, tgp_size, group_size, bits >::dst

◆ group_step_cnt

template<typename T , short BROWS, short BCOLS, short dst_ld, short reduction_dim, short tgp_size, short group_size, short bits>
short QuantizedBlockLoader< T, BROWS, BCOLS, dst_ld, reduction_dim, tgp_size, group_size, bits >::group_step_cnt

◆ group_steps

template<typename T , short BROWS, short BCOLS, short dst_ld, short reduction_dim, short tgp_size, short group_size, short bits>
constant constexpr const short QuantizedBlockLoader< T, BROWS, BCOLS, dst_ld, reduction_dim, tgp_size, group_size, bits >::group_steps = group_size / BCOLS
static constexpr

◆ group_stride

template<typename T , short BROWS, short BCOLS, short dst_ld, short reduction_dim, short tgp_size, short group_size, short bits>
const int QuantizedBlockLoader< T, BROWS, BCOLS, dst_ld, reduction_dim, tgp_size, group_size, bits >::group_stride

◆ n_reads

template<typename T , short BROWS, short BCOLS, short dst_ld, short reduction_dim, short tgp_size, short group_size, short bits>
constant constexpr const short QuantizedBlockLoader< T, BROWS, BCOLS, dst_ld, reduction_dim, tgp_size, group_size, bits >::n_reads
static constexpr
Initial value:
= (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size

Here BCOLS_PACKED is the static constant constexpr const short member defined at quantized.h:273.

◆ pack_factor

template<typename T , short BROWS, short BCOLS, short dst_ld, short reduction_dim, short tgp_size, short group_size, short bits>
constant constexpr const short QuantizedBlockLoader< T, BROWS, BCOLS, dst_ld, reduction_dim, tgp_size, group_size, bits >::pack_factor = 32 / bits
static constexpr

◆ scales

template<typename T , short BROWS, short BCOLS, short dst_ld, short reduction_dim, short tgp_size, short group_size, short bits>
const device T* QuantizedBlockLoader< T, BROWS, BCOLS, dst_ld, reduction_dim, tgp_size, group_size, bits >::scales

◆ src

template<typename T , short BROWS, short BCOLS, short dst_ld, short reduction_dim, short tgp_size, short group_size, short bits>
const device uint32_t* QuantizedBlockLoader< T, BROWS, BCOLS, dst_ld, reduction_dim, tgp_size, group_size, bits >::src

◆ src_ld

template<typename T , short BROWS, short BCOLS, short dst_ld, short reduction_dim, short tgp_size, short group_size, short bits>
const int QuantizedBlockLoader< T, BROWS, BCOLS, dst_ld, reduction_dim, tgp_size, group_size, bits >::src_ld

◆ thread_idx

template<typename T , short BROWS, short BCOLS, short dst_ld, short reduction_dim, short tgp_size, short group_size, short bits>
const short QuantizedBlockLoader< T, BROWS, BCOLS, dst_ld, reduction_dim, tgp_size, group_size, bits >::thread_idx

◆ tile_stride

template<typename T , short BROWS, short BCOLS, short dst_ld, short reduction_dim, short tgp_size, short group_size, short bits>
const int QuantizedBlockLoader< T, BROWS, BCOLS, dst_ld, reduction_dim, tgp_size, group_size, bits >::tile_stride

The documentation for this struct was generated from the following file: quantized.h