MLX
 
Loading...
Searching...
No Matches
gemv_masked.h File Reference

Go to the source code of this file.

Classes

struct  _NoMask
 
struct  ScaleOp< OutT, InT >
 
struct  GEMVKernel< T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN >
 
struct  GEMVTKernel< T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN >
 Vector-matrix multiplication. More...
 

Macros

#define MLX_MTL_CONST   static constant constexpr const
 
#define MLX_MTL_PRAGMA_UNROLL   _Pragma("clang loop unroll(full)")
 

Typedefs

typedef struct _NoMask nomask_t
 

Functions

template<typename T, typename out_mask_t, typename op_mask_t, const int BM, const int BN, const int SM, const int SN, const int TM, const int TN, const bool kDoNCBatch>
void gemv_masked (const device T *mat, const device T *in_vec, device T *out_vec, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &marix_ld, const constant int &batch_ndim, const constant int *batch_shape, const constant int64_t *vector_batch_stride, const constant int64_t *matrix_batch_stride, const device out_mask_t *out_mask, const device op_mask_t *mat_mask, const device op_mask_t *vec_mask, const constant int *mask_strides, const constant int64_t *mask_batch_strides, uint3 tid, uint3 lid, uint simd_gid, uint simd_lid)
 Matrix-vector multiplication.
 
template<typename T, typename out_mask_t, typename op_mask_t, const int BM, const int BN, const int SM, const int SN, const int TM, const int TN, const bool kDoNCBatch>
void gemv_t_masked (const device T *mat, const device T *in_vec, device T *out_vec, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &marix_ld, const constant int &batch_ndim, const constant int *batch_shape, const constant int64_t *vector_batch_stride, const constant int64_t *matrix_batch_stride, const device out_mask_t *out_mask, const device op_mask_t *mat_mask, const device op_mask_t *vec_mask, const constant int *mask_strides, const constant int64_t *mask_batch_strides, uint3 tid, uint3 lid, uint simd_gid, uint simd_lid)
 Vector-matrix multiplication.
 

Macro Definition Documentation

◆ MLX_MTL_CONST

#define MLX_MTL_CONST   static constant constexpr const

◆ MLX_MTL_PRAGMA_UNROLL

#define MLX_MTL_PRAGMA_UNROLL   _Pragma("clang loop unroll(full)")

Typedef Documentation

◆ nomask_t

typedef struct _NoMask nomask_t

Function Documentation

◆ gemv_masked()

template<typename T, typename out_mask_t, typename op_mask_t, const int BM, const int BN, const int SM, const int SN, const int TM, const int TN, const bool kDoNCBatch>
void gemv_masked ( const device T * mat,
const device T * in_vec,
device T * out_vec,
const constant int & in_vec_size,
const constant int & out_vec_size,
const constant int & marix_ld,
const constant int & batch_ndim,
const constant int * batch_shape,
const constant int64_t * vector_batch_stride,
const constant int64_t * matrix_batch_stride,
const device out_mask_t * out_mask,
const device op_mask_t * mat_mask,
const device op_mask_t * vec_mask,
const constant int * mask_strides,
const constant int64_t * mask_batch_strides,
uint3 tid,
uint3 lid,
uint simd_gid,
uint simd_lid )

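Matrix-vector multiplication with masking: `out_mask` (of type `out_mask_t`) applies to the output, and `mat_mask` / `vec_mask` (of type `op_mask_t`) apply to the matrix and vector operands.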

◆ gemv_t_masked()

template<typename T, typename out_mask_t, typename op_mask_t, const int BM, const int BN, const int SM, const int SN, const int TM, const int TN, const bool kDoNCBatch>
void gemv_t_masked ( const device T * mat,
const device T * in_vec,
device T * out_vec,
const constant int & in_vec_size,
const constant int & out_vec_size,
const constant int & marix_ld,
const constant int & batch_ndim,
const constant int * batch_shape,
const constant int64_t * vector_batch_stride,
const constant int64_t * matrix_batch_stride,
const device out_mask_t * out_mask,
const device op_mask_t * mat_mask,
const device op_mask_t * vec_mask,
const constant int * mask_strides,
const constant int64_t * mask_batch_strides,
uint3 tid,
uint3 lid,
uint simd_gid,
uint simd_lid )

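Vector-matrix multiplication with masking: `out_mask` (of type `out_mask_t`) applies to the output, and `mat_mask` / `vec_mask` (of type `op_mask_t`) apply to the matrix and vector operands.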