4template [[host_name("{name}")]] [[kernel]] void 
    5implicit_gemm_conv_2d<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}, {n_channels}, {small_filter}>( 
    6    const device {itype}* A [[buffer(0)]], 
    7    const device {itype}* B [[buffer(1)]], 
    8    device {itype}* C [[buffer(2)]], 
    9    const constant MLXConvParams<2>* params [[buffer(3)]], 
   10    const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], 
   11    uint3 tid [[threadgroup_position_in_grid]], 
   12    uint3 lid [[thread_position_in_threadgroup]], 
   13    uint simd_gid [[simdgroup_index_in_threadgroup]], 
   14    uint simd_lid [[thread_index_in_simdgroup]]); 
   18template [[host_name("{name}")]] [[kernel]] void 
   19    implicit_gemm_conv_2d_general<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}>( 
   20        const device {itype}* A [[buffer(0)]], 
   21        const device {itype}* B [[buffer(1)]], 
   22        device {itype}* C [[buffer(2)]], 
   23        const constant MLXConvParams<2>* params [[buffer(3)]], 
   24        const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], 
   25        const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]], 
   26        const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]], 
   27        const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]], 
   28        uint3 tid [[threadgroup_position_in_grid]], 
   29        uint3 lid [[thread_position_in_threadgroup]], 
   30        uint simd_gid [[simdgroup_index_in_threadgroup]], 
   31        uint simd_lid [[thread_index_in_simdgroup]]); 
constexpr std::string_view steel_conv_kernels
Definition: steel_conv.h, line 3
 
constexpr std::string_view steel_conv_general_kernels
Definition: steel_conv.h, line 17