Go to the source code of this file.

template<typename T, int BM, int BN, int BK, int WM, int WN, typename AccumType = float, typename Epilogue = TransformNone<T, AccumType>>
void implicit_gemm_conv_2d_general(const device T *A, const device T *B, device T *C, const constant MLXConvParams<2> *params, const constant ImplicitGemmConv2DParams *gemm_params, const constant Conv2DGeneralJumpParams *jump_params, const constant Conv2DGeneralBaseInfo *base_h, const constant Conv2DGeneralBaseInfo *base_w, uint3 tid, uint3 lid, uint simd_gid, uint simd_lid)

◆ implicit_gemm_conv_2d_general()
template<typename T, int BM, int BN, int BK, int WM, int WN, typename AccumType = float, typename Epilogue = TransformNone<T, AccumType>>
void implicit_gemm_conv_2d_general(
    const device T *A,
    const device T *B,
    device T *C,
    const constant MLXConvParams<2> *params,
    const constant ImplicitGemmConv2DParams *gemm_params,
    const constant Conv2DGeneralJumpParams *jump_params,
    const constant Conv2DGeneralBaseInfo *base_h,
    const constant Conv2DGeneralBaseInfo *base_w,
    uint3 tid,
    uint3 lid,
    uint simd_gid,
    uint simd_lid)