MLX
 
Loading...
Searching...
No Matches
steel_conv_general.h File Reference

Go to the source code of this file.

Functions

template<typename T, int BM, int BN, int BK, int WM, int WN, typename AccumType = float, typename Epilogue = TransformNone<T, AccumType>>
void implicit_gemm_conv_2d_general(const device T *A, const device T *B, device T *C, const constant MLXConvParams<2> *params, const constant ImplicitGemmConv2DParams *gemm_params, const constant Conv2DGeneralJumpParams *jump_params, const constant Conv2DGeneralBaseInfo *base_h, const constant Conv2DGeneralBaseInfo *base_w, uint3 tid, uint3 lid, uint simd_gid, uint simd_lid)
 

Function Documentation

◆ implicit_gemm_conv_2d_general()

template<typename T, int BM, int BN, int BK, int WM, int WN, typename AccumType = float, typename Epilogue = TransformNone<T, AccumType>>
void implicit_gemm_conv_2d_general(const device T *A,
const device T * B,
device T * C,
const constant MLXConvParams<2> *params,
const constant ImplicitGemmConv2DParams * gemm_params,
const constant Conv2DGeneralJumpParams * jump_params,
const constant Conv2DGeneralBaseInfo * base_h,
const constant Conv2DGeneralBaseInfo * base_w,
uint3 tid,
uint3 lid,
uint simd_gid,
uint simd_lid )