MLX
 
Loading...
Searching...
No Matches
steel_conv.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
3#include <metal_stdlib>
4
5using namespace metal;
6
/// Implicit-GEMM 2D convolution kernel (steel backend).
///
/// Each threadgroup produces one BM x BN tile of the output matrix C. The
/// reduction is walked in `gemm_params->gemm_k_iterations` steps: on every
/// step an input (A) tile and a weight (B) tile are staged in threadgroup
/// memory by the selected block loaders, then accumulated with BlockMMA.
/// The finished tile is written back with `store_result_safe`, clamped to
/// the valid M / N extent so edge tiles do not write out of range.
///
/// Template parameters:
///   T            element type of A, B and C.
///   BM, BN, BK   tile sizes; C tiles are BM rows x BN columns, and each
///                reduction step consumes BK elements.
///   WM, WN       simdgroup layout; the threadgroup runs WM * WN * 32 threads.
///   N_CHANNELS   when non-zero and <= 4, selects the small-channel loaders
///                for both the input and the weights.
///   SMALL_FILTER otherwise selects the small-filter (true) or large-filter
///                (false) input loader.
///
/// Grid mapping: tid.x / tid.y select the output tile (after swizzling with
/// `gemm_params->swizzle_log`), tid.z selects the convolution group.
template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    int N_CHANNELS = 0,
    bool SMALL_FILTER = false>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
implicit_gemm_conv_2d(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    device T* C [[buffer(2)]],
    const constant MLXConvParams<2>* params [[buffer(3)]],
    const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  using namespace mlx::steel;

  // The per-thread position is not needed: loaders and MMA are driven by
  // the simdgroup index / lane instead.
  (void)lid;

  // A (input patches) is read row-major, B (weights) is read transposed.
  constexpr bool transpose_a = false;
  constexpr bool transpose_b = true;
  // 16 bytes of padding per threadgroup-tile row, expressed in elements.
  constexpr short tgp_padding_a = 16 / sizeof(T);
  constexpr short tgp_padding_b = 16 / sizeof(T);

  constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a;
  constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b;
  constexpr short shape_a_rows = (transpose_a ? BK : BM);
  constexpr short shape_b_rows = (transpose_b ? BN : BK);
  constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows;
  constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows;

  constexpr short tgp_size = WM * WN * 32;

  // Input loader
  using loader_a_t = typename metal::conditional_t<
      // Check for small channel specialization
      N_CHANNELS != 0 && N_CHANNELS <= 4,

      // Go to small channel specialization
      Conv2DInputBlockLoaderSmallChannels<
          T,
          BM,
          BN,
          BK,
          tgp_size,
          N_CHANNELS,
          tgp_padding_a>,

      // Else go to general loader
      typename metal::conditional_t<
          // Check if filter size is small enough
          SMALL_FILTER,

          // Go to small filter specialization
          Conv2DInputBlockLoaderSmallFilter<
              T,
              BM,
              BN,
              BK,
              tgp_size,
              tgp_padding_a>,

          // Else go to large filter generalization
          Conv2DInputBlockLoaderLargeFilter<
              T,
              BM,
              BN,
              BK,
              tgp_size,
              tgp_padding_a>>>;

  // Weight loader
  using loader_b_t = typename metal::conditional_t<
      // Check for small channel specialization
      N_CHANNELS != 0 && N_CHANNELS <= 4,

      // Go to small channel specialization
      Conv2DWeightBlockLoaderSmallChannels<
          T,
          BM,
          BN,
          BK,
          tgp_size,
          N_CHANNELS,
          tgp_padding_b>,

      // Else go to general loader
      Conv2DWeightBlockLoader<T, BM, BN, BK, tgp_size, tgp_padding_b>>;

  using mma_t = BlockMMA<
      T,
      T,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      shape_a_cols,
      shape_b_cols>;

  // Staging tiles shared by the whole threadgroup.
  threadgroup T As[tgp_mem_size_a];
  threadgroup T Bs[tgp_mem_size_b];

  // Swizzle the threadgroup coordinates: groups of 2^swizzle_log
  // consecutive tid.x values are spread over consecutive tile rows.
  // NOTE(review): presumably done for cache locality — confirm with the
  // host-side dispatch that chooses swizzle_log.
  const int tid_y = ((tid.y) << gemm_params->swizzle_log) +
      ((tid.x) & ((1 << gemm_params->swizzle_log) - 1));
  const int tid_x = (tid.x) >> gemm_params->swizzle_log;

  // Threadgroups that swizzle past the tile grid have no work.
  if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) {
    return;
  }

  const int c_row = tid_y * BM;
  const int c_col = tid_x * BN;
  const int K = gemm_params->K;
  const int N = gemm_params->N;
  const int C_per_group = params->C / params->groups;

  // Groups: offset each operand to the slice owned by group tid.z.
  A += tid.z * C_per_group;
  B += tid.z * N * K;
  C += tid.z * N;

  // Move B / C to this threadgroup's output tile. C rows hold the outputs
  // of all groups, so its row stride is N * groups.
  B += c_col * K;
  C += c_row * (N * params->groups) + c_col;

  const int2 offsets_a(0, c_row);
  const int2 offsets_b(0, c_col);

  // Prepare threadgroup loading operations
  loader_a_t loader_a(
      A, As, offsets_a, params, gemm_params, simd_gid, simd_lid);
  loader_b_t loader_b(
      B, Bs, offsets_b, params, gemm_params, simd_gid, simd_lid);

  // Prepare threadgroup mma operation
  mma_t mma_op(simd_gid, simd_lid);

  const int gemm_k_iterations = gemm_params->gemm_k_iterations;
  for (int k = 0; k < gemm_k_iterations; k++) {
    // Ensure every simdgroup finished reading the previous tiles before
    // they are overwritten.
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // Load elements into threadgroup.
    // NOTE(review): load_unsafe does no bounds check here; out-of-range
    // handling is assumed to live inside the loaders — confirm there.
    loader_a.load_unsafe();
    loader_b.load_unsafe();

    // Make the freshly loaded tiles visible to all simdgroups.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Multiply and accumulate threadgroup elements
    mma_op.mma(As, Bs);

    // Prepare for next iteration
    loader_a.next();
    loader_b.next();
  }

  // Execution-only barrier (no memory fence) before the epilogue.
  threadgroup_barrier(mem_flags::mem_none);

  // Store results to device memory, clamping the tile to the valid extent
  // for tiles on the bottom / right edge of the output.
  const short tgp_bm = min(BM, gemm_params->M - c_row);
  const short tgp_bn = min(BN, gemm_params->N - c_col);
  const int ldc = N * params->groups;
  mma_op.store_result_safe(C, ldc, short2(tgp_bn, tgp_bm));
}
Definition bf16_math.h:226
METAL_FUNC bfloat16_t min(bfloat16_t x, bfloat16_t y)
Definition bf16_math.h:232
Definition attn.h:19
void implicit_gemm_conv_2d(const device T *A, const device T *B, device T *C, const constant MLXConvParams< 2 > *params, const constant ImplicitGemmConv2DParams *gemm_params, uint3 tid, uint3 lid, uint simd_gid, uint simd_lid)
Definition steel_conv.h:17
Definition params.h:6
Definition mma.h:449
Definition loader_channel_l.h:23
Definition loader_channel_n.h:59
Definition loader_channel_l.h:171
Definition loader_channel_l.h:352
Definition loader_channel_n.h:203