MLX
Loading...
Searching...
No Matches
steel_conv.h
Go to the documentation of this file.
// Copyright © 2024 Apple Inc.

// Metal Shading Language source fragment: an explicit instantiation of the
// `implicit_gemm_conv_2d` kernel template, exported under a host-visible name.
//
// The `{...}` fields (`{name}`, `{itype}`, `{bm}`, `{bn}`, `{bk}`, `{wm}`,
// `{wn}`, `{n_channels}`, `{small_filter}`) are placeholders. They must be
// substituted before the text is handed to the Metal compiler.
// NOTE(review): the substitution is presumably done by a brace-based formatter
// (fmt / std::format style) in the caller — confirm there. If so, do not add
// any other literal `{` or `}` to this text.
//
// Buffer bindings declared by the instantiation:
//   0: A (input), 1: B (weights), 2: C (output),
//   3: MLXConvParams<2>, 4: ImplicitGemmConv2DParams.
constexpr std::string_view steel_conv_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
implicit_gemm_conv_2d<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}, {n_channels}, {small_filter}>(
 const device {itype}* A [[buffer(0)]],
 const device {itype}* B [[buffer(1)]],
 device {itype}* C [[buffer(2)]],
 const constant MLXConvParams<2>* params [[buffer(3)]],
 const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
 uint3 tid [[threadgroup_position_in_grid]],
 uint3 lid [[thread_position_in_threadgroup]],
 uint simd_gid [[simdgroup_index_in_threadgroup]],
 uint simd_lid [[thread_index_in_simdgroup]]);
)";

// Metal Shading Language source fragment: an explicit instantiation of the
// `implicit_gemm_conv_2d_general` kernel template, exported under a
// host-visible name.
//
// The `{...}` fields (`{name}`, `{itype}`, `{bm}`, `{bn}`, `{bk}`, `{wm}`,
// `{wn}`) are placeholders. They must be substituted before the text is
// handed to the Metal compiler.
// NOTE(review): the substitution is presumably done by a brace-based formatter
// (fmt / std::format style) in the caller — confirm there. If so, do not add
// any other literal `{` or `}` to this text.
//
// Buffer bindings declared by the instantiation:
//   0: A, 1: B, 2: C,
//   3: MLXConvParams<2>, 4: ImplicitGemmConv2DParams,
//   5: Conv2DGeneralJumpParams, 6/7: Conv2DGeneralBaseInfo (base_h / base_w).
constexpr std::string_view steel_conv_general_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
 implicit_gemm_conv_2d_general<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}>(
 const device {itype}* A [[buffer(0)]],
 const device {itype}* B [[buffer(1)]],
 device {itype}* C [[buffer(2)]],
 const constant MLXConvParams<2>* params [[buffer(3)]],
 const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
 const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]],
 const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]],
 const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]],
 uint3 tid [[threadgroup_position_in_grid]],
 uint3 lid [[thread_position_in_threadgroup]],
 uint simd_gid [[simdgroup_index_in_threadgroup]],
 uint simd_lid [[thread_index_in_simdgroup]]);
)";
constexpr std::string_view steel_conv_kernels
Definition steel_conv.h:3
constexpr std::string_view steel_conv_general_kernels
Definition steel_conv.h:17