MLX
Loading...
Searching...
No Matches
Classes | Macros | Functions | Variables
quantized.h File Reference
#include <metal_simdgroup>
#include <metal_stdlib>

Go to the source code of this file.

Classes

struct  QuantizedBlockLoader< T, BROWS, BCOLS, dst_ld, reduction_dim, tgp_size, group_size, bits >
 

Macros

#define MLX_MTL_CONST   static constant constexpr const
 

Functions

template<typename T , typename U , int values_per_thread, int bits>
U load_vector (const device T *x, thread U *x_thread)
 
template<typename T , typename U , int values_per_thread, int bits>
U load_vector_safe (const device T *x, thread U *x_thread, int N)
 
template<typename U , int values_per_thread, int bits>
U qdot (const device uint8_t *w, const thread U *x_thread, U scale, U bias, U sum)
 
template<typename U , int values_per_thread, int bits>
U qdot_safe (const device uint8_t *w, const thread U *x_thread, U scale, U bias, U sum, int N)
 
template<typename U , int values_per_thread, int bits>
void qouter (const thread uint8_t *w, U x, U scale, U bias, thread U *result)
 
template<typename U , int N, int bits>
void dequantize (const device uint8_t *w, U scale, U bias, threadgroup U *w_local)
 
template<typename T , int group_size, int bits>
METAL_FUNC void qmv_fast_impl (const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, uint3 tid, uint simd_gid, uint simd_lid)
 
template<typename T , int group_size, int bits>
METAL_FUNC void qmv_impl (const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, uint3 tid, uint simd_gid, uint simd_lid)
 
template<typename T , const int group_size, const int bits>
METAL_FUNC void qvm_impl (const device T *x, const device uint32_t *w, const device T *scales, const device T *biases, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, uint3 tid, uint simd_gid, uint simd_lid)
 
template<typename T , const int BM, const int BK, const int BN, const int group_size, const int bits, const bool aligned_N>
METAL_FUNC void qmm_t_impl (const device T *x, const device uint32_t *w, const device T *scales, const device T *biases, device T *y, threadgroup T *Xs, threadgroup T *Ws, const constant int &M, const constant int &N, const constant int &K, uint3 tid, uint lid, uint simd_gid, uint simd_lid)
 
template<typename T , const int BM, const int BK, const int BN, const int group_size, const int bits>
METAL_FUNC void qmm_n_impl (const device T *x, const device uint32_t *w, const device T *scales, const device T *biases, device T *y, threadgroup T *Xs, threadgroup T *Ws, const constant int &M, const constant int &N, const constant int &K, uint3 tid, uint lid, uint simd_gid, uint simd_lid)
 
template<typename T >
METAL_FUNC void adjust_matrix_offsets (const device T *&x, const device uint32_t *&w, const device T *&scales, const device T *&biases, const device uint32_t *lhs_indices, const device uint32_t *rhs_indices, device T *&y, int output_stride, const constant int &batch_ndims, const constant int *batch_shape, const constant size_t *lhs_strides, const constant size_t *rhs_strides, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, uint3 tid)
 
template<typename T , int group_size, int bits>
void qmv_fast (const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, uint3 tid, uint simd_gid, uint simd_lid)
 
template<typename T , const int group_size, const int bits>
void qmv (const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, uint3 tid, uint simd_gid, uint simd_lid)
 
template<typename T , const int group_size, const int bits>
void qvm (const device T *x, const device uint32_t *w, const device T *scales, const device T *biases, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, uint3 tid, uint simd_gid, uint simd_lid)
 
template<typename T , const int group_size, const int bits, const bool aligned_N, const int BM = 32, const int BK = 32, const int BN = 32>
void qmm_t (const device T *x, const device uint32_t *w, const device T *scales, const device T *biases, device T *y, const constant int &M, const constant int &N, const constant int &K, uint3 tid, uint lid, uint simd_gid, uint simd_lid)
 
template<typename T , const int group_size, const int bits, const int BM = 32, const int BK = 32, const int BN = 32>
void qmm_n (const device T *x, const device uint32_t *w, const device T *scales, const device T *biases, device T *y, const constant int &M, const constant int &N, const constant int &K, uint3 tid, uint lid, uint simd_gid, uint simd_lid)
 
template<typename T , int group_size, int bits>
void bs_qmv_fast (const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, const device uint32_t *lhs_indices, const device uint32_t *rhs_indices, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &batch_ndims, const constant int *batch_shape, const constant size_t *lhs_strides, const constant size_t *rhs_strides, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, uint3 tid, uint simd_gid, uint simd_lid)
 
template<typename T , int group_size, int bits>
void bs_qmv (const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, const device uint32_t *lhs_indices, const device uint32_t *rhs_indices, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &batch_ndims, const constant int *batch_shape, const constant size_t *lhs_strides, const constant size_t *rhs_strides, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, uint3 tid, uint simd_gid, uint simd_lid)
 
template<typename T , int group_size, int bits>
void bs_qvm (const device T *x, const device uint32_t *w, const device T *scales, const device T *biases, const device uint32_t *lhs_indices, const device uint32_t *rhs_indices, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &batch_ndims, const constant int *batch_shape, const constant size_t *lhs_strides, const constant size_t *rhs_strides, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, uint3 tid, uint simd_gid, uint simd_lid)
 
template<typename T , const int group_size, const int bits, const bool aligned_N, const int BM = 32, const int BK = 32, const int BN = 32>
void bs_qmm_t (const device T *x, const device uint32_t *w, const device T *scales, const device T *biases, const device uint32_t *lhs_indices, const device uint32_t *rhs_indices, device T *y, const constant int &M, const constant int &N, const constant int &K, const constant int &batch_ndims, const constant int *batch_shape, const constant size_t *lhs_strides, const constant size_t *rhs_strides, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, uint3 tid, uint lid, uint simd_gid, uint simd_lid)
 
template<typename T , const int group_size, const int bits, const int BM = 32, const int BK = 32, const int BN = 32>
void bs_qmm_n (const device T *x, const device uint32_t *w, const device T *scales, const device T *biases, const device uint32_t *lhs_indices, const device uint32_t *rhs_indices, device T *y, const constant int &M, const constant int &N, const constant int &K, const constant int &batch_ndims, const constant int *batch_shape, const constant size_t *lhs_strides, const constant size_t *rhs_strides, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, uint3 tid, uint lid, uint simd_gid, uint simd_lid)
 

Variables

static constant constexpr const int SIMD_SIZE = 32
 

Macro Definition Documentation

◆ MLX_MTL_CONST

#define MLX_MTL_CONST   static constant constexpr const

Function Documentation

◆ adjust_matrix_offsets()

template<typename T >
METAL_FUNC void adjust_matrix_offsets ( const device T *& x,
const device uint32_t *& w,
const device T *& scales,
const device T *& biases,
const device uint32_t * lhs_indices,
const device uint32_t * rhs_indices,
device T *& y,
int output_stride,
const constant int & batch_ndims,
const constant int * batch_shape,
const constant size_t * lhs_strides,
const constant size_t * rhs_strides,
const constant int & x_batch_ndims,
const constant int * x_shape,
const constant size_t * x_strides,
const constant int & w_batch_ndims,
const constant int * w_shape,
const constant size_t * w_strides,
const constant size_t * s_strides,
const constant size_t * b_strides,
uint3 tid )

◆ bs_qmm_n()

template<typename T , const int group_size, const int bits, const int BM = 32, const int BK = 32, const int BN = 32>
void bs_qmm_n ( const device T * x,
const device uint32_t * w,
const device T * scales,
const device T * biases,
const device uint32_t * lhs_indices,
const device uint32_t * rhs_indices,
device T * y,
const constant int & M,
const constant int & N,
const constant int & K,
const constant int & batch_ndims,
const constant int * batch_shape,
const constant size_t * lhs_strides,
const constant size_t * rhs_strides,
const constant int & x_batch_ndims,
const constant int * x_shape,
const constant size_t * x_strides,
const constant int & w_batch_ndims,
const constant int * w_shape,
const constant size_t * w_strides,
const constant size_t * s_strides,
const constant size_t * b_strides,
uint3 tid,
uint lid,
uint simd_gid,
uint simd_lid )

◆ bs_qmm_t()

template<typename T , const int group_size, const int bits, const bool aligned_N, const int BM = 32, const int BK = 32, const int BN = 32>
void bs_qmm_t ( const device T * x,
const device uint32_t * w,
const device T * scales,
const device T * biases,
const device uint32_t * lhs_indices,
const device uint32_t * rhs_indices,
device T * y,
const constant int & M,
const constant int & N,
const constant int & K,
const constant int & batch_ndims,
const constant int * batch_shape,
const constant size_t * lhs_strides,
const constant size_t * rhs_strides,
const constant int & x_batch_ndims,
const constant int * x_shape,
const constant size_t * x_strides,
const constant int & w_batch_ndims,
const constant int * w_shape,
const constant size_t * w_strides,
const constant size_t * s_strides,
const constant size_t * b_strides,
uint3 tid,
uint lid,
uint simd_gid,
uint simd_lid )

◆ bs_qmv()

template<typename T , int group_size, int bits>
void bs_qmv ( const device uint32_t * w,
const device T * scales,
const device T * biases,
const device T * x,
const device uint32_t * lhs_indices,
const device uint32_t * rhs_indices,
device T * y,
const constant int & in_vec_size,
const constant int & out_vec_size,
const constant int & batch_ndims,
const constant int * batch_shape,
const constant size_t * lhs_strides,
const constant size_t * rhs_strides,
const constant int & x_batch_ndims,
const constant int * x_shape,
const constant size_t * x_strides,
const constant int & w_batch_ndims,
const constant int * w_shape,
const constant size_t * w_strides,
const constant size_t * s_strides,
const constant size_t * b_strides,
uint3 tid,
uint simd_gid,
uint simd_lid )

◆ bs_qmv_fast()

template<typename T , int group_size, int bits>
void bs_qmv_fast ( const device uint32_t * w,
const device T * scales,
const device T * biases,
const device T * x,
const device uint32_t * lhs_indices,
const device uint32_t * rhs_indices,
device T * y,
const constant int & in_vec_size,
const constant int & out_vec_size,
const constant int & batch_ndims,
const constant int * batch_shape,
const constant size_t * lhs_strides,
const constant size_t * rhs_strides,
const constant int & x_batch_ndims,
const constant int * x_shape,
const constant size_t * x_strides,
const constant int & w_batch_ndims,
const constant int * w_shape,
const constant size_t * w_strides,
const constant size_t * s_strides,
const constant size_t * b_strides,
uint3 tid,
uint simd_gid,
uint simd_lid )

◆ bs_qvm()

template<typename T , int group_size, int bits>
void bs_qvm ( const device T * x,
const device uint32_t * w,
const device T * scales,
const device T * biases,
const device uint32_t * lhs_indices,
const device uint32_t * rhs_indices,
device T * y,
const constant int & in_vec_size,
const constant int & out_vec_size,
const constant int & batch_ndims,
const constant int * batch_shape,
const constant size_t * lhs_strides,
const constant size_t * rhs_strides,
const constant int & x_batch_ndims,
const constant int * x_shape,
const constant size_t * x_strides,
const constant int & w_batch_ndims,
const constant int * w_shape,
const constant size_t * w_strides,
const constant size_t * s_strides,
const constant size_t * b_strides,
uint3 tid,
uint simd_gid,
uint simd_lid )

◆ dequantize()

template<typename U , int N, int bits>
void dequantize ( const device uint8_t * w,
U scale,
U bias,
threadgroup U * w_local )
inline

◆ load_vector()

template<typename T , typename U , int values_per_thread, int bits>
U load_vector ( const device T * x,
thread U * x_thread )
inline

◆ load_vector_safe()

template<typename T , typename U , int values_per_thread, int bits>
U load_vector_safe ( const device T * x,
thread U * x_thread,
int N )
inline

◆ qdot()

template<typename U , int values_per_thread, int bits>
U qdot ( const device uint8_t * w,
const thread U * x_thread,
U scale,
U bias,
U sum )
inline

◆ qdot_safe()

template<typename U , int values_per_thread, int bits>
U qdot_safe ( const device uint8_t * w,
const thread U * x_thread,
U scale,
U bias,
U sum,
int N )
inline

◆ qmm_n()

template<typename T , const int group_size, const int bits, const int BM = 32, const int BK = 32, const int BN = 32>
void qmm_n ( const device T * x,
const device uint32_t * w,
const device T * scales,
const device T * biases,
device T * y,
const constant int & M,
const constant int & N,
const constant int & K,
uint3 tid,
uint lid,
uint simd_gid,
uint simd_lid )

◆ qmm_n_impl()

template<typename T , const int BM, const int BK, const int BN, const int group_size, const int bits>
METAL_FUNC void qmm_n_impl ( const device T * x,
const device uint32_t * w,
const device T * scales,
const device T * biases,
device T * y,
threadgroup T * Xs,
threadgroup T * Ws,
const constant int & M,
const constant int & N,
const constant int & K,
uint3 tid,
uint lid,
uint simd_gid,
uint simd_lid )

◆ qmm_t()

template<typename T , const int group_size, const int bits, const bool aligned_N, const int BM = 32, const int BK = 32, const int BN = 32>
void qmm_t ( const device T * x,
const device uint32_t * w,
const device T * scales,
const device T * biases,
device T * y,
const constant int & M,
const constant int & N,
const constant int & K,
uint3 tid,
uint lid,
uint simd_gid,
uint simd_lid )

◆ qmm_t_impl()

template<typename T , const int BM, const int BK, const int BN, const int group_size, const int bits, const bool aligned_N>
METAL_FUNC void qmm_t_impl ( const device T * x,
const device uint32_t * w,
const device T * scales,
const device T * biases,
device T * y,
threadgroup T * Xs,
threadgroup T * Ws,
const constant int & M,
const constant int & N,
const constant int & K,
uint3 tid,
uint lid,
uint simd_gid,
uint simd_lid )

◆ qmv()

template<typename T , const int group_size, const int bits>
void qmv ( const device uint32_t * w,
const device T * scales,
const device T * biases,
const device T * x,
device T * y,
const constant int & in_vec_size,
const constant int & out_vec_size,
uint3 tid,
uint simd_gid,
uint simd_lid )

◆ qmv_fast()

template<typename T , int group_size, int bits>
void qmv_fast ( const device uint32_t * w,
const device T * scales,
const device T * biases,
const device T * x,
device T * y,
const constant int & in_vec_size,
const constant int & out_vec_size,
uint3 tid,
uint simd_gid,
uint simd_lid )

◆ qmv_fast_impl()

template<typename T , int group_size, int bits>
METAL_FUNC void qmv_fast_impl ( const device uint32_t * w,
const device T * scales,
const device T * biases,
const device T * x,
device T * y,
const constant int & in_vec_size,
const constant int & out_vec_size,
uint3 tid,
uint simd_gid,
uint simd_lid )

◆ qmv_impl()

template<typename T , int group_size, int bits>
METAL_FUNC void qmv_impl ( const device uint32_t * w,
const device T * scales,
const device T * biases,
const device T * x,
device T * y,
const constant int & in_vec_size,
const constant int & out_vec_size,
uint3 tid,
uint simd_gid,
uint simd_lid )

◆ qouter()

template<typename U , int values_per_thread, int bits>
void qouter ( const thread uint8_t * w,
U x,
U scale,
U bias,
thread U * result )
inline

◆ qvm()

template<typename T , const int group_size, const int bits>
void qvm ( const device T * x,
const device uint32_t * w,
const device T * scales,
const device T * biases,
device T * y,
const constant int & in_vec_size,
const constant int & out_vec_size,
uint3 tid,
uint simd_gid,
uint simd_lid )

◆ qvm_impl()

template<typename T , const int group_size, const int bits>
METAL_FUNC void qvm_impl ( const device T * x,
const device uint32_t * w,
const device T * scales,
const device T * biases,
device T * y,
const constant int & in_vec_size,
const constant int & out_vec_size,
uint3 tid,
uint simd_gid,
uint simd_lid )

Variable Documentation

◆ SIMD_SIZE

constant constexpr const int SIMD_SIZE = 32
static, constexpr