MLX

steel_gemm_masked.h File Reference

Go to the source code of this file.

Classes

struct  _NoMask
 
struct  ScaleOp< OutT, InT >
 

Functions

template<typename T, typename out_mask_t, typename op_mask_t, int BM, int BN, int BK, int WM, int WN, bool transpose_a, bool transpose_b, bool MN_aligned, bool K_aligned>
void block_masked_gemm (const device T *A, const device T *B, device T *D, const constant GEMMParams *params, const constant int *batch_shape, const constant int64_t *batch_strides, const device out_mask_t *out_mask, const device op_mask_t *lhs_mask, const device op_mask_t *rhs_mask, const constant int *mask_strides, uint simd_lane_id, uint simd_group_id, uint3 tid, uint3 lid)
 
template<typename T, int BM, int BN, int BK, int WM, int WN, bool transpose_a, bool transpose_b, bool MN_aligned, bool K_aligned, bool has_operand_mask = false>
void block_masked_gemm (const device T *A, const device T *B, device T *D, const constant GEMMParams *params, const constant int *batch_shape, const constant int64_t *batch_strides, const device bool *out_mask, const device bool *lhs_mask, const device bool *rhs_mask, const constant int *mask_strides, uint simd_lane_id, uint simd_group_id, uint3 tid, uint3 lid)
 

Function Documentation

◆ block_masked_gemm() [1/2]

template<typename T, int BM, int BN, int BK, int WM, int WN, bool transpose_a, bool transpose_b, bool MN_aligned, bool K_aligned, bool has_operand_mask = false>
void block_masked_gemm ( const device T * A,
const device T * B,
device T * D,
const constant GEMMParams * params,
const constant int * batch_shape,
const constant int64_t * batch_strides,
const device bool * out_mask,
const device bool * lhs_mask,
const device bool * rhs_mask,
const constant int * mask_strides,
uint simd_lane_id,
uint simd_group_id,
uint3 tid,
uint3 lid )

◆ block_masked_gemm() [2/2]

template<typename T, typename out_mask_t, typename op_mask_t, int BM, int BN, int BK, int WM, int WN, bool transpose_a, bool transpose_b, bool MN_aligned, bool K_aligned>
void block_masked_gemm ( const device T * A,
const device T * B,
device T * D,
const constant GEMMParams * params,
const constant int * batch_shape,
const constant int64_t * batch_strides,
const device out_mask_t * out_mask,
const device op_mask_t * lhs_mask,
const device op_mask_t * rhs_mask,
const constant int * mask_strides,
uint simd_lane_id,
uint simd_group_id,
uint3 tid,
uint3 lid )