MLX
 
Loading...
Searching...
No Matches
steel_gemm_fused.h File Reference

Go to the source code of this file.

Functions

template<typename T, int BM, int BN, int BK, int WM, int WN, bool transpose_a, bool transpose_b, typename AccumType = float>
void gemm (const device T *A, const device T *B, const device T *C, device T *D, const constant GEMMParams *params, const constant GEMMAddMMParams *addmm_params, const constant int *batch_shape, const constant int64_t *batch_strides, const constant uint32_t *lhs_indices, const constant uint32_t *rhs_indices, const constant uint32_t *C_indices, const constant int *operand_shape, const constant int64_t *operand_strides, const constant packed_int3 &operand_batch_ndim, uint simd_lane_id, uint simd_group_id, uint3 tid, uint3 lid)
 

Variables

constant bool has_batch
 
constant bool use_out_source
 
constant bool do_axpby
 
constant bool align_M
 
constant bool align_N
 
constant bool align_K
 
constant bool do_gather
 
constant bool gather_bias = do_gather && use_out_source
 

Function Documentation

◆ gemm()

template<typename T, int BM, int BN, int BK, int WM, int WN, bool transpose_a, bool transpose_b, typename AccumType = float>
void gemm(
    const device T *A,
    const device T *B,
    const device T *C,
    device T *D,
    const constant GEMMParams *params,
    const constant GEMMAddMMParams *addmm_params,
    const constant int *batch_shape,
    const constant int64_t *batch_strides,
    const constant uint32_t *lhs_indices,
    const constant uint32_t *rhs_indices,
    const constant uint32_t *C_indices,
    const constant int *operand_shape,
    const constant int64_t *operand_strides,
    const constant packed_int3 &operand_batch_ndim,
    uint simd_lane_id,
    uint simd_group_id,
    uint3 tid,
    uint3 lid)

Variable Documentation

◆ align_K

constant bool align_K

◆ align_M

constant bool align_M

◆ align_N

constant bool align_N

◆ do_axpby

constant bool do_axpby

◆ do_gather

constant bool do_gather

◆ gather_bias

constant bool gather_bias = do_gather && use_out_source

◆ has_batch

constant bool has_batch

◆ use_out_source

constant bool use_out_source