4template [[host_name("{name}")]] [[kernel]] void
5implicit_gemm_conv_2d<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}, {n_channels}, {small_filter}>(
6 const device {itype}* A [[buffer(0)]],
7 const device {itype}* B [[buffer(1)]],
8 device {itype}* C [[buffer(2)]],
9 const constant MLXConvParams<2>* params [[buffer(3)]],
10 const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
11 uint3 tid [[threadgroup_position_in_grid]],
12 uint3 lid [[thread_position_in_threadgroup]],
13 uint simd_gid [[simdgroup_index_in_threadgroup]],
14 uint simd_lid [[thread_index_in_simdgroup]]);
18template [[host_name("{name}")]] [[kernel]] void
19 implicit_gemm_conv_2d_general<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}>(
20 const device {itype}* A [[buffer(0)]],
21 const device {itype}* B [[buffer(1)]],
22 device {itype}* C [[buffer(2)]],
23 const constant MLXConvParams<2>* params [[buffer(3)]],
24 const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
25 const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]],
26 const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]],
27 const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]],
28 uint3 tid [[threadgroup_position_in_grid]],
29 uint3 lid [[thread_position_in_threadgroup]],
30 uint simd_gid [[simdgroup_index_in_threadgroup]],
31 uint simd_lid [[thread_index_in_simdgroup]]);
constexpr std::string_view steel_conv_kernels
Definition steel_conv.h:3
constexpr std::string_view steel_conv_general_kernels
Definition steel_conv.h:17