MLX
Loading...
Searching...
No Matches
gemm.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
3#pragma once
4
10
11using namespace metal;
12
14// GEMM kernel class
16
namespace mlx {
namespace steel {

// Empty tag type whose only content is its three alignment flags (M, N, K).
// GEMMKernel::gemm_loop takes one as a defaulted trailing argument; the value
// itself is never read.
template <bool M_aligned, bool N_aligned, bool K_aligned>
struct LoopAlignment {};

// Tiled GEMM kernel.
//
// Each threadgroup computes one BM x BN tile of D. A and B are walked along K
// in slices of depth BK: every slice is staged into threadgroup memory (As /
// Bs) by the two loaders, multiplied and accumulated by the BlockMMA object,
// and the finished tile is written to D by that same object.
//
// Template parameters:
//   T            element type of A, B and of the threadgroup tiles
//   U            element type of D
//   BM, BN       rows / columns of the output tile owned by one threadgroup
//   BK           K-depth of one main-loop slice
//   WM, WN       simdgroup grid; tgp_size below is WM * WN * 32
//   transpose_a  layout flag for A: selects the tile shape, the padded stride
//                and how the A pointer is offset for this tile
//   transpose_b  the same for B
//   MN_aligned   when true, every tile is assumed full in M and N and the
//                unchecked load/store path is used
//   K_aligned    when true, no partial K slice exists and the tail is skipped
//   AccumType    accumulator type forwarded to BlockMMA
//   Epilogue     output transform forwarded to BlockMMA (TransformNone by
//                default)
template <
    typename T,
    typename U,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    bool MN_aligned,
    bool K_aligned,
    typename AccumType = typename AccumHelper<T>::accum_type,
    typename Epilogue = TransformNone<U, AccumType>>
struct GEMMKernel {
  // Each staged tile row is padded by 16 bytes (16 / sizeof(T) elements),
  // presumably to reduce threadgroup-memory bank conflicts.
  STEEL_CONST short tgp_padding_a = 16 / sizeof(T);
  STEEL_CONST short tgp_padding_b = 16 / sizeof(T);

  // Element counts of the padded A and B tiles held in threadgroup memory.
  STEEL_CONST short tgp_mem_size_a =
      transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
  STEEL_CONST short tgp_mem_size_b =
      transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
  STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;

  // Threads per threadgroup (WM x WN simdgroups; assumes 32 lanes each).
  STEEL_CONST short tgp_size = WM * WN * 32;

  // Loader staging K-slices of A into As; tile shape and padded stride depend
  // on transpose_a.
  using loader_a_t = BlockLoader<
      T,
      transpose_a ? BK : BM,
      transpose_a ? BM : BK,
      transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
      !transpose_a,
      tgp_size>;
  // Loader staging K-slices of B into Bs; tile shape and padded stride depend
  // on transpose_b.
  using loader_b_t = BlockLoader<
      T,
      transpose_b ? BN : BK,
      transpose_b ? BK : BN,
      transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
      transpose_b,
      tgp_size>;
  // Threadgroup MMA reading the padded As/Bs tiles and accumulating in
  // AccumType; also responsible for writing the result tile to D.
  using mma_t = BlockMMA<
      T,
      U,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
      transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
      AccumType,
      Epilogue>;

  // K loop for one output tile whose alignment is known at compile time.
  //
  // M_aligned / N_aligned pick unchecked (load_unsafe) or bounds-checked
  // (load_safe) loads for A / B in the main loop. When K_aligned_ is false, one
  // more partial slice of depth lbk is loaded with load_safe and accumulated
  // after the gemm_k_iterations full slices.
  //
  //   gemm_k_iterations  number of full BK slices
  //   tgp_bm, tgp_bn     valid rows / columns of this output tile
  //   lbk                K-depth of the trailing partial slice
  //   l                  unused alignment tag (see LoopAlignment)
  template <bool M_aligned, bool N_aligned, bool K_aligned_>
  static METAL_FUNC void gemm_loop(
      threadgroup T* As [[threadgroup(0)]],
      threadgroup T* Bs [[threadgroup(1)]],
      const int gemm_k_iterations,
      thread loader_a_t& loader_a,
      thread loader_b_t& loader_b,
      thread mma_t& mma_op,
      thread const short& tgp_bm,
      thread const short& tgp_bn,
      thread const short& lbk,
      LoopAlignment<M_aligned, N_aligned, K_aligned_> l = {}) {
    // The tag carries no data; silence the unused-parameter warning.
    (void)l;

    // Valid extent of the A / B slices; only consulted by load_safe.
    short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);

    short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);

    for (int k = 0; k < gemm_k_iterations; k++) {
      // Previous mma must be done reading As/Bs before they are overwritten.
      threadgroup_barrier(mem_flags::mem_threadgroup);
      // Load elements into threadgroup
      if (M_aligned) {
        loader_a.load_unsafe();
      } else {
        loader_a.load_safe(tile_dims_A);
      }

      if (N_aligned) {
        loader_b.load_unsafe();
      } else {
        loader_b.load_safe(tile_dims_B);
      }

      // Loads must be visible to every simdgroup before the mma reads them.
      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Multiply and accumulate threadgroup elements
      mma_op.mma(As, Bs);

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();
    }

    if (!K_aligned_) {
      // Wait for the last main-loop mma before reusing As/Bs.
      threadgroup_barrier(mem_flags::mem_threadgroup);

      short2 tile_dims_A_last =
          transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
      short2 tile_dims_B_last =
          transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);

      loader_a.load_safe(tile_dims_A_last);
      loader_b.load_safe(tile_dims_B_last);

      threadgroup_barrier(mem_flags::mem_threadgroup);

      mma_op.mma(As, Bs);
    }
  }

  // Kernel body: computes the BM x BN tile of D selected by the (swizzled)
  // threadgroup position and stores it to D.
  //
  //   A, B, D  operand / result buffers (D has leading dimension params->ldd)
  //   params   problem description (M, N, K, leading dimensions, tile counts,
  //            swizzle_log, gemm_k_iterations_aligned)
  //   As, Bs   threadgroup staging buffers for the A / B tiles
  static METAL_FUNC void run(
      const device T* A [[buffer(0)]],
      const device T* B [[buffer(1)]],
      device U* D [[buffer(2)]],
      const constant GEMMParams* params [[buffer(3)]],
      threadgroup T* As [[threadgroup(0)]],
      threadgroup T* Bs [[threadgroup(1)]],
      uint simd_lane_id [[thread_index_in_simdgroup]],
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]]) {
    // Pacifying compiler
    (void)lid;

    // Undo the threadgroup swizzle: each run of 2^swizzle_log consecutive
    // tid.x values maps to consecutive tile rows of the same tile column.
    const int tid_y = ((tid.y) << params->swizzle_log) +
        ((tid.x) & ((1 << params->swizzle_log) - 1));
    const int tid_x = (tid.x) >> params->swizzle_log;

    // Threadgroups outside the tile grid have nothing to compute.
    if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
      return;
    }

    threadgroup_barrier(mem_flags::mem_none);

    // Top-left corner of this threadgroup's output tile. The pointer offsets
    // are computed in size_t (widened before multiplying by the leading
    // dimensions).
    const int c_row = tid_y * BM;
    const int c_col = tid_x * BN;
    const size_t c_row_long = size_t(c_row);
    const size_t c_col_long = size_t(c_col);

    A += transpose_a ? c_row_long : c_row_long * params->lda;
    B += transpose_b ? c_col_long * params->ldb : c_col_long;
    D += c_row_long * params->ldd + c_col_long;

    // Prepare threadgroup loading operations
    thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
    thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);

    // Prepare threadgroup mma operation
    thread mma_t mma_op(simd_group_id, simd_lane_id);

    int gemm_k_iterations = params->gemm_k_iterations_aligned;

    if (MN_aligned) {
      // Every tile is full in M and N: unchecked loads and stores throughout.
      for (int k = 0; k < gemm_k_iterations; k++) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        // Load elements into threadgroup
        loader_a.load_unsafe();
        loader_b.load_unsafe();

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);

        // Prepare for next iteration
        loader_a.next();
        loader_b.next();
      }

      // Loop tail: one partial K slice, loaded with bounds checks along K.
      if (!K_aligned) {
        // The tail load overwrites As/Bs, so every simdgroup must be done
        // reading them in the last mma above. mem_threadgroup orders those
        // threadgroup-memory accesses, as gemm_loop does; mem_none would be an
        // execution barrier only.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        int lbk = params->K - params->gemm_k_iterations_aligned * BK;
        short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
        short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);

        loader_a.load_safe(tile_dims_A);
        loader_b.load_safe(tile_dims_B);

        threadgroup_barrier(mem_flags::mem_threadgroup);

        mma_op.mma(As, Bs);
      }

      // Store results to device memory
      mma_op.store_result(D, params->ldd);
      return;

    } else {
      // MN unaligned: clamp the tile to the matrix edge, then dispatch to the
      // gemm_loop instantiation whose M/N flags match this tile so that only
      // edge tiles pay for bounds checks.
      short tgp_bm = min(BM, params->M - c_row);
      short tgp_bn = min(BN, params->N - c_col);
      short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;

      if (tgp_bm == BM && tgp_bn == BN) {
        // Full tile in both M and N.
        gemm_loop<true, true, K_aligned>(
            As,
            Bs,
            gemm_k_iterations,
            loader_a,
            loader_b,
            mma_op,
            tgp_bm,
            tgp_bn,
            leftover_bk);

        mma_op.store_result(D, params->ldd);
        return;

      } else if (tgp_bn == BN) {
        // Partial in M only.
        gemm_loop<false, true, K_aligned>(
            As,
            Bs,
            gemm_k_iterations,
            loader_a,
            loader_b,
            mma_op,
            tgp_bm,
            tgp_bn,
            leftover_bk);

        mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
        return;

      } else if (tgp_bm == BM) {
        // Partial in N only.
        gemm_loop<true, false, K_aligned>(
            As,
            Bs,
            gemm_k_iterations,
            loader_a,
            loader_b,
            mma_op,
            tgp_bm,
            tgp_bn,
            leftover_bk);

        mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
        return;

      } else {
        // Partial in both M and N.
        gemm_loop<false, false, K_aligned>(
            As,
            Bs,
            gemm_k_iterations,
            loader_a,
            loader_b,
            mma_op,
            tgp_bm,
            tgp_bn,
            leftover_bk);

        mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
        return;
      }
    }
  }
};

} // namespace steel
} // namespace mlx
Definition bf16.h:265
METAL_FUNC bfloat16_t min(bfloat16_t x, bfloat16_t y)
Definition bf16_math.h:234
Definition allocator.h:7
#define STEEL_CONST
Definition defines.h:3
float accum_type
Definition transforms.h:57
Definition loader.h:25
Definition mma.h:377
Definition gemm.h:37
static METAL_FUNC void run(const device T *A, const device T *B, device U *D, const constant GEMMParams *params, threadgroup T *As, threadgroup T *Bs, uint simd_lane_id, uint simd_group_id, uint3 tid, uint3 lid)
Definition gemm.h:140
STEEL_CONST short tgp_mem_size_b
Definition gemm.h:42
STEEL_CONST short tgp_mem_size
Definition gemm.h:44
static METAL_FUNC void gemm_loop(threadgroup T *As, threadgroup T *Bs, const int gemm_k_iterations, thread loader_a_t &loader_a, thread loader_b_t &loader_b, thread mma_t &mma_op, thread const short &tgp_bm, thread const short &tgp_bn, thread const short &lbk, LoopAlignment< M_aligned, N_aligned, K_aligned_ > l={})
Definition gemm.h:79
STEEL_CONST short tgp_size
Definition gemm.h:46
STEEL_CONST short tgp_mem_size_a
Definition gemm.h:40
STEEL_CONST short tgp_padding_b
Definition gemm.h:39
STEEL_CONST short tgp_padding_a
Definition gemm.h:38
Definition params.h:12
Definition gemm.h:21
Definition transforms.h:15