MLX
 
steel_gemm_splitk.h File Reference

Go to the source code of this file.

Functions

template<typename T, typename U, int BM, int BN, int BK, int WM, int WN, bool transpose_a, bool transpose_b, bool MN_aligned, bool K_aligned>
void gemm_splitk (const device T *A, const device T *B, device U *C, const constant GEMMSpiltKParams *params, uint simd_lane_id, uint simd_group_id, uint3 tid, uint3 lid)
 
template<typename AccT, typename OutT, typename Epilogue = TransformNone<OutT, AccT>>
void gemm_splitk_accum (const device AccT *C_split, device OutT *D, const constant int &k_partitions, const constant int &partition_stride, const constant int &ldd, uint2 gid)
 
template<typename AccT, typename OutT, typename Epilogue = TransformAxpby<OutT, AccT>>
void gemm_splitk_accum_axpby (const device AccT *C_split, device OutT *D, const constant int &k_partitions, const constant int &partition_stride, const constant int &ldd, const device OutT *C, const constant int &ldc, const constant int &fdc, const constant float &alpha, const constant float &beta, uint2 gid)
 

Function Documentation

◆ gemm_splitk()

template<typename T, typename U, int BM, int BN, int BK, int WM, int WN, bool transpose_a, bool transpose_b, bool MN_aligned, bool K_aligned>
void gemm_splitk ( const device T * A,
const device T * B,
device U * C,
const constant GEMMSpiltKParams * params,
uint simd_lane_id,
uint simd_group_id,
uint3 tid,
uint3 lid )

◆ gemm_splitk_accum()

template<typename AccT, typename OutT, typename Epilogue = TransformNone<OutT, AccT>>
void gemm_splitk_accum ( const device AccT * C_split,
device OutT * D,
const constant int & k_partitions,
const constant int & partition_stride,
const constant int & ldd,
uint2 gid )

◆ gemm_splitk_accum_axpby()

template<typename AccT, typename OutT, typename Epilogue = TransformAxpby<OutT, AccT>>
void gemm_splitk_accum_axpby ( const device AccT * C_split,
device OutT * D,
const constant int & k_partitions,
const constant int & partition_stride,
const constant int & ldd,
const device OutT * C,
const constant int & ldc,
const constant int & fdc,
const constant float & alpha,
const constant float & beta,
uint2 gid )