MLX
steel_gemm_fused.h

// Copyright © 2024 Apple Inc.

using namespace mlx::steel;

///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////

constant bool has_batch [[function_constant(10)]];

constant bool use_out_source [[function_constant(100)]];
constant bool do_axpby [[function_constant(110)]];

constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];

constant bool do_gather [[function_constant(300)]];

constant bool gather_bias = do_gather && use_out_source;
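
// Note: the [[function_constant(i)]] flags above are specialization points.
// The host binds boolean values for them (by index) when the compute pipeline
// for a particular gemm variant is created, so the branches they guard below
// are resolved at pipeline compile time rather than per thread. gather_bias is
// derived from two of them and needs no host-provided value.
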
// clang-format off
template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    const device T* C [[buffer(2), function_constant(use_out_source)]],
    device T* D [[buffer(3)]],
    const constant GEMMParams* params [[buffer(4)]],
    const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
    const constant int* batch_shape [[buffer(6)]],
    const constant int64_t* batch_strides [[buffer(7)]],
    const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
    const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
    const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
    const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
    const constant int64_t* operand_strides [[buffer(14), function_constant(do_gather)]],
    const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on
  // Pacifying compiler
  (void)lid;

  using gemm_kernel = GEMMKernel<
      T,
      T,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      true,
      true,
      AccumType>;

  using loader_a_t = typename gemm_kernel::loader_a_t;
  using loader_b_t = typename gemm_kernel::loader_b_t;
  using mma_t = typename gemm_kernel::mma_t;

  // Find block
  const int tid_y = ((tid.y) << params->swizzle_log) +
      ((tid.x) & ((1 << params->swizzle_log) - 1));
  const int tid_x = (tid.x) >> params->swizzle_log;
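
  // (The swizzle above remaps threadgroup ids so that consecutive threadgroups
  // tend to work on nearby output tiles, which is presumably meant to improve
  // reuse of A and B blocks in the cache; swizzle_log == 0 is the identity
  // mapping.)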

  // Exit early if out of bounds
  if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
    return;
  }

  // Adjust for batch

  // Handle gather
  if (do_gather) {
    // Read indices
    uint32_t indx_A, indx_B, indx_C;

    if (has_batch) {
      const constant auto* indx_A_bstrides = batch_strides;
      const constant auto* indx_B_bstrides = batch_strides + params->batch_ndim;

      ulong2 indx_offsets = elem_to_loc_broadcast(
          tid.z,
          batch_shape,
          indx_A_bstrides,
          indx_B_bstrides,
          params->batch_ndim);
      indx_A = lhs_indices[indx_offsets.x];
      indx_B = rhs_indices[indx_offsets.y];

      if (use_out_source) {
        const constant auto* indx_C_bstrides =
            indx_B_bstrides + params->batch_ndim;
        auto indx_offset_C = elem_to_loc(
            tid.z, batch_shape, indx_C_bstrides, params->batch_ndim);
        indx_C = C_indices[indx_offset_C];
      }
    } else {
      indx_A = lhs_indices[params->batch_stride_a * tid.z];
      indx_B = rhs_indices[params->batch_stride_b * tid.z];

      if (use_out_source) {
        indx_C = C_indices[addmm_params->batch_stride_c * tid.z];
      }
    }

    // Translate indices to offsets
    int batch_ndim_A = operand_batch_ndim.x;
    const constant int* batch_shape_A = operand_shape;
    const constant auto* batch_strides_A = operand_strides;
    A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A);

    int batch_ndim_B = operand_batch_ndim.y;
    const constant int* batch_shape_B = batch_shape_A + batch_ndim_A;
    const constant auto* batch_strides_B = batch_strides_A + batch_ndim_A;
    B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B);

    if (use_out_source) {
      int batch_ndim_C = operand_batch_ndim.z;
      const constant int* batch_shape_C = batch_shape_B + batch_ndim_B;
      const constant auto* batch_strides_C = batch_strides_B + batch_ndim_B;
      C += elem_to_loc(indx_C, batch_shape_C, batch_strides_C, batch_ndim_C);
    }

  }

  // Handle regular batch
  else {
    if (has_batch) {
      const constant auto* A_bstrides = batch_strides;
      const constant auto* B_bstrides = batch_strides + params->batch_ndim;

      ulong2 batch_offsets = elem_to_loc_broadcast(
          tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);

      A += batch_offsets.x;
      B += batch_offsets.y;

      if (use_out_source) {
        const constant auto* C_bstrides = B_bstrides + params->batch_ndim;
        C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);
      }
    } else {
      A += params->batch_stride_a * tid.z;
      B += params->batch_stride_b * tid.z;

      if (use_out_source) {
        C += addmm_params->batch_stride_c * tid.z;
      }
    }
  }
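
  // Whichever path was taken, A and B (and C when use_out_source is set) now
  // point at the start of the matrices selected for this tid.z batch element.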

  D += params->batch_stride_d * tid.z;

  // Prepare threadgroup memory
  threadgroup T As[gemm_kernel::tgp_mem_size_a];
  threadgroup T Bs[gemm_kernel::tgp_mem_size_b];

  threadgroup_barrier(mem_flags::mem_none);

  // Find block in A, B, C
  const int c_row = tid_y * BM;
  const int c_col = tid_x * BN;
  const size_t c_row_long = size_t(c_row);
  const size_t c_col_long = size_t(c_col);

  A += transpose_a ? c_row_long : c_row_long * params->lda;
  B += transpose_b ? c_col_long * params->ldb : c_col_long;
  D += c_row_long * params->ldd + c_col_long;

  if (use_out_source) {
    C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc;
  }
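
  // C is addressed with both a row stride (ldc) and an element stride (fdc),
  // which appears to let the same epilogue read either a full matrix or a
  // broadcast vector as the accumulate input.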

  // Prepare threadgroup mma operation
  thread mma_t mma_op(simd_group_id, simd_lane_id);

  // Prepare threadgroup loading operations
  thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
  thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);

  // Prepare threadgroup bounds
  const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
  const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));
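
  // tgp_bm x tgp_bn is the live extent of this output tile: the full BM x BN
  // in the interior, clipped at the M/N edges when those dimensions are not
  // tile-aligned.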

  // Prepare iterations
  int gemm_k_iterations = params->gemm_k_iterations_aligned;

  // Do unaligned K iterations first
  if (!align_K) {
    const int k_last = params->gemm_k_iterations_aligned * BK;
    const int k_remain = params->K - k_last;
    const size_t k_jump_a =
        transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
    const size_t k_jump_b =
        transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);

    // Move loader source ahead to end
    loader_a.src += k_jump_a;
    loader_b.src += k_jump_b;

    // Load tile
    const short2 tile_dims_A =
        transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
    const short2 tile_dims_B =
        transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);

    loader_a.load_safe(tile_dims_A);
    loader_b.load_safe(tile_dims_B);

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Do matmul
    mma_op.mma(As, Bs);

    // Reset source back to start
    loader_a.src -= k_jump_a;
    loader_b.src -= k_jump_b;
  }
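
  // Handling the K remainder first, with bounds-checked loads, lets the main
  // loops below use the faster unchecked loads for the aligned part of K.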

  const TransformAdd<AccumType, AccumType> epilogue_op_add(
      addmm_params->alpha, addmm_params->beta);
  const TransformAxpby<AccumType, AccumType> epilogue_op_axpby(
      addmm_params->alpha, addmm_params->beta);
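
  // Two candidate epilogues when C is fused in: a plain add (D = A*B + C) or
  // an axpby-style update (D = alpha * A*B + beta * C), selected below by the
  // do_axpby constant.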

  ///////////////////////////////////////////////////////////////////////////
  // MNK aligned loop
  if (align_M && align_N) {
    // Do gemm
    for (int k = 0; k < gemm_k_iterations; k++) {
      threadgroup_barrier(mem_flags::mem_threadgroup);
      // Load elements into threadgroup
      loader_a.load_unsafe();
      loader_b.load_unsafe();

      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Multiply and accumulate threadgroup elements
      mma_op.mma(As, Bs);

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();
    }

    threadgroup_barrier(mem_flags::mem_none);

    // Do epilogue
    if (use_out_source) {
      if (do_axpby) {
        mma_op.apply_epilogue(
            C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby);
      } else {
        mma_op.apply_epilogue(
            C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add);
      }
    }

    // Store results to device memory
    return mma_op.store_result(D, params->ldd);

  }
  ///////////////////////////////////////////////////////////////////////////
  // MN unaligned loop
  else { // Loop over K - unaligned case
    const int leftover_bk = 0;
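
    // Four specializations follow, depending on whether this tile is full in
    // M and N: full tiles use the unchecked epilogue/store, while edge tiles
    // fall back to the bounds-checked *_safe variants.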

    if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
      // Do gemm
      gemm_kernel::gemm_loop(
          As,
          Bs,
          gemm_k_iterations,
          loader_a,
          loader_b,
          mma_op,
          tgp_bm,
          tgp_bn,
          leftover_bk,
          LoopAlignment<true, true, true>{});

      // Do epilogue
      if (use_out_source) {
        if (do_axpby) {
          mma_op.apply_epilogue(
              C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby);
        } else {
          mma_op.apply_epilogue(
              C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add);
        }
      }

      // Store results to device memory
      return mma_op.store_result(D, params->ldd);

    } else if (align_N || tgp_bn == BN) {
      gemm_kernel::gemm_loop(
          As,
          Bs,
          gemm_k_iterations,
          loader_a,
          loader_b,
          mma_op,
          tgp_bm,
          tgp_bn,
          leftover_bk,
          LoopAlignment<false, true, true>{});

      // Do epilogue
      if (use_out_source) {
        if (do_axpby) {
          mma_op.apply_epilogue_safe(
              C,
              addmm_params->ldc,
              addmm_params->fdc,
              short2(tgp_bn, tgp_bm),
              epilogue_op_axpby);
        } else {
          mma_op.apply_epilogue_safe(
              C,
              addmm_params->ldc,
              addmm_params->fdc,
              short2(tgp_bn, tgp_bm),
              epilogue_op_add);
        }
      }

      // Store results to device memory
      return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));

    } else if (align_M || tgp_bm == BM) {
      gemm_kernel::gemm_loop(
          As,
          Bs,
          gemm_k_iterations,
          loader_a,
          loader_b,
          mma_op,
          tgp_bm,
          tgp_bn,
          leftover_bk,
          LoopAlignment<true, false, true>{});

      // Do epilogue
      if (use_out_source) {
        if (do_axpby) {
          mma_op.apply_epilogue_safe(
              C,
              addmm_params->ldc,
              addmm_params->fdc,
              short2(tgp_bn, tgp_bm),
              epilogue_op_axpby);
        } else {
          mma_op.apply_epilogue_safe(
              C,
              addmm_params->ldc,
              addmm_params->fdc,
              short2(tgp_bn, tgp_bm),
              epilogue_op_add);
        }
      }

      // Store results to device memory
      return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));

    } else {
      gemm_kernel::gemm_loop(
          As,
          Bs,
          gemm_k_iterations,
          loader_a,
          loader_b,
          mma_op,
          tgp_bm,
          tgp_bn,
          leftover_bk,
          LoopAlignment<false, false, true>{});

      // Do epilogue
      if (use_out_source) {
        if (do_axpby) {
          mma_op.apply_epilogue_safe(
              C,
              addmm_params->ldc,
              addmm_params->fdc,
              short2(tgp_bn, tgp_bm),
              epilogue_op_axpby);
        } else {
          mma_op.apply_epilogue_safe(
              C,
              addmm_params->ldc,
              addmm_params->fdc,
              short2(tgp_bn, tgp_bm),
              epilogue_op_add);
        }
      }

      // Store results to device memory
      return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
    }
  }
}
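
For context, the function constants declared at the top of this file are bound on the host when the compute pipeline for a given gemm variant is built. The sketch below shows how that binding could look with metal-cpp; it is a minimal illustration of the constant indices used here, not MLX's actual dispatch code, and the kernel name, helper name, and flag choices are hypothetical.

// Hypothetical host-side specialization sketch (metal-cpp).
#include <Metal/Metal.hpp>

MTL::ComputePipelineState* make_gemm_pipeline(
    MTL::Device* device,
    MTL::Library* lib,
    bool batched,
    bool fuse_c,
    bool axpby,
    bool m_aligned,
    bool n_aligned,
    bool k_aligned,
    bool gather) {
  auto* fc = MTL::FunctionConstantValues::alloc()->init();
  fc->setConstantValue(&batched, MTL::DataTypeBool, NS::UInteger(10)); // has_batch
  fc->setConstantValue(&fuse_c, MTL::DataTypeBool, NS::UInteger(100)); // use_out_source
  fc->setConstantValue(&axpby, MTL::DataTypeBool, NS::UInteger(110)); // do_axpby
  fc->setConstantValue(&m_aligned, MTL::DataTypeBool, NS::UInteger(200)); // align_M
  fc->setConstantValue(&n_aligned, MTL::DataTypeBool, NS::UInteger(201)); // align_N
  fc->setConstantValue(&k_aligned, MTL::DataTypeBool, NS::UInteger(202)); // align_K
  fc->setConstantValue(&gather, MTL::DataTypeBool, NS::UInteger(300)); // do_gather
  // gather_bias is derived inside the shader and is not set from the host.

  NS::Error* error = nullptr;
  // The name below is a placeholder for whatever host_name the template
  // instantiation of gemm<...> was given in the companion .metal sources.
  auto* fn = lib->newFunction(
      NS::String::string("steel_gemm_fused_f16_32_32_16_2_2_nt", NS::UTF8StringEncoding),
      fc,
      &error);
  auto* pso = device->newComputePipelineState(fn, &error);
  fc->release();
  fn->release();
  return pso;
}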