steel_gemm_fused.h
// Copyright © 2024 Apple Inc.

using namespace mlx::steel;

///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////

constant bool has_batch [[function_constant(10)]];

constant bool use_out_source [[function_constant(100)]];
constant bool do_axpby [[function_constant(110)]];

constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];

constant bool do_gather [[function_constant(300)]];

constant bool gather_bias = do_gather && use_out_source;

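// Host-side sketch (hypothetical, for illustration only; names and indices
// other than the constant IDs above are assumptions): these function constants
// are bound when the compute pipeline is created, so each combination compiles
// to a specialized variant and arguments guarded by function_constant(...) are
// only present when their constant is true. With metal-cpp this might look
// like:
//
//   auto* fc = MTL::FunctionConstantValues::alloc()->init();
//   bool yes = true, no = false;
//   fc->setConstantValue(&no, MTL::DataTypeBool, 10);    // has_batch
//   fc->setConstantValue(&yes, MTL::DataTypeBool, 100);  // use_out_source
//   fc->setConstantValue(&no, MTL::DataTypeBool, 110);   // do_axpby
//   // ... 200/201/202 for align_M/N/K, 300 for do_gather ...
//   NS::Error* err = nullptr;
//   auto* fn = library->newFunction(
//       NS::String::string("gemm", NS::UTF8StringEncoding), fc, &err);
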
// clang-format off
template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    const device T* C [[buffer(2), function_constant(use_out_source)]],
    device T* D [[buffer(3)]],
    const constant GEMMParams* params [[buffer(4)]],
    const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
    const constant int* batch_shape [[buffer(6)]],
    const constant size_t* batch_strides [[buffer(7)]],
    const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
    const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
    const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
    const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
    const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]],
    const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on
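  // Launch geometry, as implied by the attributes and index arguments above:
  // each threadgroup holds WM x WN simdgroups of 32 threads and computes one
  // BM x BN tile of the output D, tid.z indexes the batch element, and the
  // accumulation is done in AccumType (float by default).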
  // Pacifying compiler
  (void)lid;

  using gemm_kernel = GEMMKernel<
      T,
      T,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      true,
      true,
      AccumType>;

  using loader_a_t = typename gemm_kernel::loader_a_t;
  using loader_b_t = typename gemm_kernel::loader_b_t;
  using mma_t = typename gemm_kernel::mma_t;

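  // The tile indices are swizzled: consecutive tid.x values are folded into
  // groups of 2^swizzle_log rows, so neighboring threadgroups land on nearby
  // output tiles (presumably to improve reuse of A and B in the caches).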
  // Find block
  const int tid_y = ((tid.y) << params->swizzle_log) +
      ((tid.x) & ((1 << params->swizzle_log) - 1));
  const int tid_x = (tid.x) >> params->swizzle_log;

  // Exit early if out of bounds
  if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
    return;
  }

  // Adjust for batch

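  // When do_gather is set, the kernel performs an indirect (gathered) batched
  // GEMM: lhs_indices / rhs_indices (and C_indices when the out source is in
  // use) select which batch element of each operand this threadgroup reads,
  // while operand_shape / operand_strides describe each operand's own batch
  // dimensions.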
  // Handle gather
  if (do_gather) {
    // Read indices
    uint32_t indx_A, indx_B, indx_C;

    if (has_batch) {
      const constant size_t* indx_A_bstrides = batch_strides;
      const constant size_t* indx_B_bstrides =
          batch_strides + params->batch_ndim;

      ulong2 indx_offsets = elem_to_loc_broadcast(
          tid.z,
          batch_shape,
          indx_A_bstrides,
          indx_B_bstrides,
          params->batch_ndim);
      indx_A = lhs_indices[indx_offsets.x];
      indx_B = rhs_indices[indx_offsets.y];

      if (use_out_source) {
        const constant size_t* indx_C_bstrides =
            indx_B_bstrides + params->batch_ndim;
        auto indx_offset_C = elem_to_loc(
            tid.z, batch_shape, indx_C_bstrides, params->batch_ndim);
        indx_C = C_indices[indx_offset_C];
      }
    } else {
      indx_A = lhs_indices[params->batch_stride_a * tid.z];
      indx_B = rhs_indices[params->batch_stride_b * tid.z];

      if (use_out_source) {
        indx_C = C_indices[addmm_params->batch_stride_c * tid.z];
      }
    }

    // Translate indices to offsets
    int batch_ndim_A = operand_batch_ndim.x;
    const constant int* batch_shape_A = operand_shape;
    const constant size_t* batch_strides_A = operand_strides;
    A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A);

    int batch_ndim_B = operand_batch_ndim.y;
    const constant int* batch_shape_B = batch_shape_A + batch_ndim_A;
    const constant size_t* batch_strides_B = batch_strides_A + batch_ndim_A;
    B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B);

    if (use_out_source) {
      int batch_ndim_C = operand_batch_ndim.z;
      const constant int* batch_shape_C = batch_shape_B + batch_ndim_B;
      const constant size_t* batch_strides_C = batch_strides_B + batch_ndim_B;
      C += elem_to_loc(indx_C, batch_shape_C, batch_strides_C, batch_ndim_C);
    }
  }

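  // For the regular batched path below, batch_strides packs the batch strides
  // of A first, then B, then (when use_out_source is set) C, each with
  // params->batch_ndim entries.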
  // Handle regular batch
  else {
    if (has_batch) {
      const constant size_t* A_bstrides = batch_strides;
      const constant size_t* B_bstrides = batch_strides + params->batch_ndim;

      ulong2 batch_offsets = elem_to_loc_broadcast(
          tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);

      A += batch_offsets.x;
      B += batch_offsets.y;

      if (use_out_source) {
        const constant size_t* C_bstrides = B_bstrides + params->batch_ndim;
        C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);
      }
    } else {
      A += params->batch_stride_a * tid.z;
      B += params->batch_stride_b * tid.z;

      if (use_out_source) {
        C += addmm_params->batch_stride_c * tid.z;
      }
    }
  }

  D += params->batch_stride_d * tid.z;

  // Prepare threadgroup memory
  threadgroup T As[gemm_kernel::tgp_mem_size_a];
  threadgroup T Bs[gemm_kernel::tgp_mem_size_b];

  threadgroup_barrier(mem_flags::mem_none);

  // Find block in A, B, C
  const int c_row = tid_y * BM;
  const int c_col = tid_x * BN;
  const size_t c_row_long = size_t(c_row);
  const size_t c_col_long = size_t(c_col);

  A += transpose_a ? c_row_long : c_row_long * params->lda;
  B += transpose_b ? c_col_long * params->ldb : c_col_long;
  D += c_row_long * params->ldd + c_col_long;

  if (use_out_source) {
    C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc;
  }

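  // Note that C above is advanced with both a row stride (ldc) and an explicit
  // column stride (fdc); the separate fdc presumably lets the addmm source C
  // be strided or broadcast along its fastest dimension rather than assumed
  // contiguous.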
  // Prepare threadgroup mma operation
  thread mma_t mma_op(simd_group_id, simd_lane_id);

  // Prepare threadgroup loading operations
  thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
  thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);

  // Prepare threadgroup bounds
  const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
  const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));

  // Prepare iterations
  int gemm_k_iterations = params->gemm_k_iterations_aligned;

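  // When K is not a multiple of BK (!align_K), the leftover K slice is
  // processed first with bounds-checked (safe) loads; the main loop afterwards
  // runs only the aligned iterations and can use the cheaper unchecked loads.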
  // Do unaligned K iterations first
  if (!align_K) {
    const int k_last = params->gemm_k_iterations_aligned * BK;
    const int k_remain = params->K - k_last;
    const size_t k_jump_a =
        transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
    const size_t k_jump_b =
        transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);

    // Move loader source ahead to end
    loader_a.src += k_jump_a;
    loader_b.src += k_jump_b;

    // Load tile
    const short2 tile_dims_A =
        transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
    const short2 tile_dims_B =
        transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);

    loader_a.load_safe(tile_dims_A);
    loader_b.load_safe(tile_dims_B);

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Do matmul
    mma_op.mma(As, Bs);

    // Reset source back to start
    loader_a.src -= k_jump_a;
    loader_b.src -= k_jump_b;
  }

  const TransformAdd<AccumType, AccumType> epilogue_op_add(
      addmm_params->alpha, addmm_params->beta);
  const TransformAxpby<AccumType, AccumType> epilogue_op_axpby(
      addmm_params->alpha, addmm_params->beta);

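  // Epilogue selection: when do_axpby is set, the stored result appears to be
  // alpha * (A·B) + beta * C (TransformAxpby); otherwise TransformAdd simply
  // adds C to the accumulated product. Either is applied only when
  // use_out_source is set.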
  // MNK aligned loop
  if (align_M && align_N) {
    // Do gemm
    for (int k = 0; k < gemm_k_iterations; k++) {
      threadgroup_barrier(mem_flags::mem_threadgroup);
      // Load elements into threadgroup
      loader_a.load_unsafe();
      loader_b.load_unsafe();

      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Multiply and accumulate threadgroup elements
      mma_op.mma(As, Bs);

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();
    }

    threadgroup_barrier(mem_flags::mem_none);

    // Do epilogue
    if (use_out_source) {
      if (do_axpby) {
        mma_op.apply_epilogue(
            C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby);
      } else {
        mma_op.apply_epilogue(
            C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add);
      }
    }

    // Store results to device memory
    return mma_op.store_result(D, params->ldd);
  }

  // MN unaligned loop
  else { // Loop over K - unaligned case
    const int leftover_bk = 0;

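    // The edge tiles fall into four cases below, depending on whether this
    // tile is full along M and/or N (tgp_bm == BM, tgp_bn == BN): full tiles
    // use the unchecked epilogue/store paths, while partial tiles use the
    // bounds-checked *_safe variants. The LoopAlignment tag passed to
    // gemm_loop carries the same per-dimension information so the main loop
    // can specialize its bounds handling accordingly.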
    if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
      // Do gemm
      gemm_kernel::gemm_loop(
          As,
          Bs,
          gemm_k_iterations,
          loader_a,
          loader_b,
          mma_op,
          tgp_bm,
          tgp_bn,
          leftover_bk,
          LoopAlignment<true, true, true>{});

      // Do epilogue
      if (use_out_source) {
        if (do_axpby) {
          mma_op.apply_epilogue(
              C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby);
        } else {
          mma_op.apply_epilogue(
              C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add);
        }
      }

      // Store results to device memory
      return mma_op.store_result(D, params->ldd);

    } else if (align_N || tgp_bn == BN) {
      gemm_kernel::gemm_loop(
          As,
          Bs,
          gemm_k_iterations,
          loader_a,
          loader_b,
          mma_op,
          tgp_bm,
          tgp_bn,
          leftover_bk,
          LoopAlignment<false, true, true>{});

      // Do epilogue
      if (use_out_source) {
        if (do_axpby) {
          mma_op.apply_epilogue_safe(
              C,
              addmm_params->ldc,
              addmm_params->fdc,
              short2(tgp_bn, tgp_bm),
              epilogue_op_axpby);
        } else {
          mma_op.apply_epilogue_safe(
              C,
              addmm_params->ldc,
              addmm_params->fdc,
              short2(tgp_bn, tgp_bm),
              epilogue_op_add);
        }
      }

      // Store results to device memory
      return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));

    } else if (align_M || tgp_bm == BM) {
      gemm_kernel::gemm_loop(
          As,
          Bs,
          gemm_k_iterations,
          loader_a,
          loader_b,
          mma_op,
          tgp_bm,
          tgp_bn,
          leftover_bk,
          LoopAlignment<true, false, true>{});

      // Do epilogue
      if (use_out_source) {
        if (do_axpby) {
          mma_op.apply_epilogue_safe(
              C,
              addmm_params->ldc,
              addmm_params->fdc,
              short2(tgp_bn, tgp_bm),
              epilogue_op_axpby);
        } else {
          mma_op.apply_epilogue_safe(
              C,
              addmm_params->ldc,
              addmm_params->fdc,
              short2(tgp_bn, tgp_bm),
              epilogue_op_add);
        }
      }

      // Store results to device memory
      return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));

    } else {
      gemm_kernel::gemm_loop(
          As,
          Bs,
          gemm_k_iterations,
          loader_a,
          loader_b,
          mma_op,
          tgp_bm,
          tgp_bn,
          leftover_bk,
          LoopAlignment<false, false, true>{});

      // Do epilogue
      if (use_out_source) {
        if (do_axpby) {
          mma_op.apply_epilogue_safe(
              C,
              addmm_params->ldc,
              addmm_params->fdc,
              short2(tgp_bn, tgp_bm),
              epilogue_op_axpby);
        } else {
          mma_op.apply_epilogue_safe(
              C,
              addmm_params->ldc,
              addmm_params->fdc,
              short2(tgp_bn, tgp_bm),
              epilogue_op_add);
        }
      }

      // Store results to device memory
      return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
    }
  }
}