MLX
Loading...
Searching...
No Matches
Classes | Typedefs | Functions
steel_gemm_masked.h File Reference
#include "mlx/backend/metal/kernels/steel/defines.h"

Go to the source code of this file.

Classes

struct  _NoMask
 
struct  ScaleOp< OutT, InT >
 

Typedefs

typedef struct _NoMask nomask_t
 

Functions

template<typename T, typename out_mask_t, typename op_mask_t, int BM, int BN, int BK, int WM, int WN, bool transpose_a, bool transpose_b, bool MN_aligned, bool K_aligned>
void block_masked_gemm(const device T* A, const device T* B, device T* D, const constant GEMMParams* params, const constant int* batch_shape, const constant size_t* batch_strides, const device out_mask_t* out_mask, const device op_mask_t* lhs_mask, const device op_mask_t* rhs_mask, const constant int* mask_strides, uint simd_lane_id, uint simd_group_id, uint3 tid, uint3 lid)
 
template<typename T, int BM, int BN, int BK, int WM, int WN, bool transpose_a, bool transpose_b, bool MN_aligned, bool K_aligned, bool has_operand_mask = false>
void block_masked_gemm(const device T* A, const device T* B, device T* D, const constant GEMMParams* params, const constant int* batch_shape, const constant size_t* batch_strides, const device bool* out_mask, const device bool* lhs_mask, const device bool* rhs_mask, const constant int* mask_strides, uint simd_lane_id, uint simd_group_id, uint3 tid, uint3 lid)
 

Typedef Documentation

◆ nomask_t

typedef struct _NoMask nomask_t

Function Documentation

◆ block_masked_gemm() [1/2]

template<typename T, int BM, int BN, int BK, int WM, int WN, bool transpose_a, bool transpose_b, bool MN_aligned, bool K_aligned, bool has_operand_mask = false>
void block_masked_gemm(
    const device T* A,
    const device T* B,
    device T* D,
    const constant GEMMParams* params,
    const constant int* batch_shape,
    const constant size_t* batch_strides,
    const device bool* out_mask,
    const device bool* lhs_mask,
    const device bool* rhs_mask,
    const constant int* mask_strides,
    uint simd_lane_id,
    uint simd_group_id,
    uint3 tid,
    uint3 lid)

◆ block_masked_gemm() [2/2]

template<typename T, typename out_mask_t, typename op_mask_t, int BM, int BN, int BK, int WM, int WN, bool transpose_a, bool transpose_b, bool MN_aligned, bool K_aligned>
void block_masked_gemm(
    const device T* A,
    const device T* B,
    device T* D,
    const constant GEMMParams* params,
    const constant int* batch_shape,
    const constant size_t* batch_strides,
    const device out_mask_t* out_mask,
    const device op_mask_t* lhs_mask,
    const device op_mask_t* rhs_mask,
    const constant int* mask_strides,
    uint simd_lane_id,
    uint simd_group_id,
    uint3 tid,
    uint3 lid)