MLX
Loading...
Searching...
No Matches
mma.h
Go to the documentation of this file.
// Copyright © 2024 Apple Inc.

#pragma once

#include <metal_simdgroup>
#include <metal_simdgroup_matrix>
#include <metal_stdlib>

// NOTE(review): the project includes below were dropped by the doc
// extraction (original lines 9-10). Restored from usage — STEEL_CONST /
// STEEL_PRAGMA_UNROLL live in steel/defines.h (per the cross-reference
// index) and TransformNone in the steel transforms header. Verify paths
// against the repository.
#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
#include "mlx/backend/metal/kernels/steel/utils.h"

using namespace metal;

///////////////////////////////////////////////////////////////////////////////
// MMA helper
///////////////////////////////////////////////////////////////////////////////

namespace mlx {
namespace steel {

// Per-simdgroup tile of a threadgroup GEMM: accumulates a (BM, BK) x (BK, BN)
// product out of threadgroup memory into registers via 8x8 simdgroup matrices,
// then writes the epilogue-transformed result to device memory.
//
//   T           — element type of the threadgroup tiles As / Bs
//   U           — element type of the device output D (and bias input C)
//   BM, BN, BK  — threadgroup tile sizes
//   WM, WN      — simdgroup (warp) grid within the threadgroup tile
//   transpose_a / transpose_b — layout of As / Bs in threadgroup memory
//   lda_tgp / ldb_tgp         — leading dims of As / Bs in threadgroup memory
//   AccumType   — accumulation type (default float)
//   Epilogue    — element-wise transform applied on store (default identity)
template <
    typename T,
    typename U,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    short lda_tgp,
    short ldb_tgp,
    typename AccumType = float,
    typename Epilogue = TransformNone<U, AccumType>>
struct BlockMMA {
  // Warp tile simdgroup matrix strides along M
  STEEL_CONST short TM_stride = 8 * WM;
  // Warp tile simdgroup matrix strides along N
  STEEL_CONST short TN_stride = 8 * WN;

  // Warp tile size along M
  STEEL_CONST short TM = BM / TM_stride;
  // Warp tile size along N
  STEEL_CONST short TN = BN / TN_stride;

  // Strides of A, B along reduction axis
  // (restored: these two declaration heads were dropped by the doc
  // extraction; the index lists them as mma.h:47 and mma.h:49, and mma()
  // below reads simd_stride_a / simd_stride_b.)
  STEEL_CONST short simd_stride_a = {
      transpose_a ? TM_stride : TM_stride * lda_tgp};
  STEEL_CONST short simd_stride_b = {
      transpose_b ? TN_stride * ldb_tgp : TN_stride};

  // Jump between elements
  STEEL_CONST short jump_a = {transpose_a ? lda_tgp : 1};
  STEEL_CONST short jump_b = {transpose_b ? ldb_tgp : 1};

  // Offset to the next 8x8 tile along the reduction (K) axis
  STEEL_CONST short tile_stride_a = {transpose_a ? 8 * lda_tgp : 8};
  STEEL_CONST short tile_stride_b = {transpose_b ? 8 : 8 * ldb_tgp};

  // Simdgroup matrices
  simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
  simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
  simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
      simdgroup_matrix<AccumType, 8, 8>(0)};

  // Offsets within threadgroup
  const short tm;
  const short tn;

  // This thread's position inside an 8x8 simdgroup matrix
  short sm;
  short sn;

  // Precomputed element offsets of this thread into As / Bs
  short As_offset;
  short Bs_offset;

  /* Constructor */
  METAL_FUNC BlockMMA(
      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
      ushort simd_lane_id [[thread_index_in_simdgroup]])
      : tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
    // Determine thread position in simdgroup matrix
    short qid = simd_lane_id / 4;
    sm = (qid & 4) + (simd_lane_id / 2) % 4;
    sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;

    // Determine thread and simdgroup offset
    As_offset =
        transpose_a ? ((sn)*lda_tgp + (tm + sm)) : ((sn) + (tm + sm) * lda_tgp);
    Bs_offset =
        transpose_b ? ((tn + sn) * ldb_tgp + (sm)) : ((sm)*ldb_tgp + (tn + sn));
  }

  /* (BM, BK) X (BK, BN) multiply accumulate function */
  METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
    // Adjust for simdgroup and thread location
    As += As_offset;
    Bs += Bs_offset;

    // Iterate over BK in blocks of 8
    STEEL_PRAGMA_UNROLL
    for (short kk = 0; kk < BK; kk += 8) {
      simdgroup_barrier(mem_flags::mem_none);

      // Load elements from threadgroup A as simdgroup matrices
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < TM; i++) {
        Asimd[i].thread_elements()[0] =
            static_cast<AccumType>(As[i * simd_stride_a + 0]);
        Asimd[i].thread_elements()[1] =
            static_cast<AccumType>(As[i * simd_stride_a + jump_a]);
      }

      simdgroup_barrier(mem_flags::mem_none);

      // Load elements from threadgroup B as simdgroup matrices
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < TN; j++) {
        Bsimd[j].thread_elements()[0] =
            static_cast<AccumType>(Bs[j * simd_stride_b + 0]);
        Bsimd[j].thread_elements()[1] =
            static_cast<AccumType>(Bs[j * simd_stride_b + jump_b]);
      }

      simdgroup_barrier(mem_flags::mem_none);

      // Multiply and accumulate into result simdgroup matrices
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < TM; i++) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < TN; j++) {
          // Serpentine order over N: odd rows walk j backwards, so the first
          // Bsimd tile used on row i+1 is the last one used on row i.
          short j_serp = (i % 2) ? (TN - 1 - j) : j;

          simdgroup_multiply_accumulate(
              results[i * TN + j_serp],
              Asimd[i],
              Bsimd[j_serp],
              results[i * TN + j_serp]);
        }
      }

      // Progress to next simdgroup tile
      As += tile_stride_a;
      Bs += tile_stride_b;
    }
  }

  /* Store results from simdgroup_matrix results into device memory */
  METAL_FUNC void store_result(device U* D, const int ldd) const {
    // Adjust for simdgroup and thread location
    D += (sm + tm) * ldd + tn + sn;

    // Loop over all simdgroup tiles
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < TM; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < TN; j++) {
        // Get accumulated result and associated offset in C
        thread const auto& accum = results[i * TN + j].thread_elements();
        int offset = (i * TM_stride) * ldd + (j * TN_stride);

        // Apply epilogue
        U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])};

        // Write out D
        D[offset] = outs[0];
        D[offset + 1] = outs[1];
      }
    }
  }

  /* Bounds-checked store: writes only elements inside dst_tile_dims (x = N
   * extent, y = M extent, relative to this threadgroup tile). */
  METAL_FUNC void
  store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) const {
    // Adjust for simdgroup and thread location
    D += (sm + tm) * ldd + (tn + sn);
    dst_tile_dims -= short2(tn + sn, sm + tm);

    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
      return;

    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < TM; i++) {
      if (i * TM_stride < dst_tile_dims.y) {
        STEEL_PRAGMA_UNROLL
        for (int j = 0; j < TN; j++) {
          // Get accumulated result and associated offset in C
          thread const auto& accum = results[i * TN + j].thread_elements();
          int offset = (i * TM_stride) * ldd + (j * TN_stride);

          // Apply epilogue and output C
          if (j * TN_stride < dst_tile_dims.x) {
            D[offset] = Epilogue::apply(accum[0]);
          }

          // Second element of the pair gets its own bound check
          if (j * TN_stride + 1 < dst_tile_dims.x) {
            D[offset + 1] = Epilogue::apply(accum[1]);
          }
        }
      }
    }
  }

  /* Apply a unary epilogue in place to the accumulated results */
  template <typename UnaryEpilogue>
  METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) {
    // Loop over all simdgroup tiles
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < TM; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < TN; j++) {
        // Get accumulated result and associated offset in C
        thread auto& accum = results[i * TN + j].thread_elements();

        // Apply epilogue
        accum[0] = epilogue_op.apply(accum[0]);
        accum[1] = epilogue_op.apply(accum[1]);
      }
    }
  }

  /* Apply a binary epilogue in place, combining each accumulator element
   * with the corresponding element of C (row stride ldc, column stride fdc) */
  template <typename BinaryEpilogue>
  METAL_FUNC void apply_epilogue(
      const device U* C,
      const int ldc,
      const int fdc,
      thread const BinaryEpilogue& epilogue_op) {
    // Adjust for simdgroup and thread location
    C += (sm + tm) * ldc + (tn + sn) * fdc;

    // Loop over all simdgroup tiles
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < TM; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < TN; j++) {
        // Get accumulated result and associated offset in C
        thread auto& accum = results[i * TN + j].thread_elements();
        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;

        // Apply epilogue
        accum[0] = epilogue_op.apply(accum[0], C[offset_c]);
        accum[1] = epilogue_op.apply(accum[1], C[offset_c + fdc]);
      }
    }
  }

  /* Bounds-checked binary epilogue: out-of-range C elements read as 0 */
  template <typename BinaryEpilogue>
  METAL_FUNC void apply_epilogue_safe(
      const device U* C,
      const int ldc,
      const int fdc,
      short2 dst_tile_dims,
      thread const BinaryEpilogue& epilogue_op) {
    // Adjust for simdgroup and thread location
    C += (sm + tm) * ldc + (tn + sn) * fdc;
    dst_tile_dims -= short2(tn + sn, sm + tm);

    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
      return;

    // Loop over all simdgroup tiles
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < TM; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < TN; j++) {
        // Get accumulated result and associated offset in C
        thread auto& accum = results[i * TN + j].thread_elements();
        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;

        // Read C, zero-filling elements that fall outside the tile bounds
        U c_elems[2] = {0};

        if ((j * TN_stride + 1) < dst_tile_dims.x) {
          c_elems[0] = C[offset_c];
          c_elems[1] = C[offset_c + fdc];
        } else if ((j * TN_stride) < dst_tile_dims.x) {
          c_elems[0] = C[offset_c];
        }

        // Apply epilogue
        accum[0] = epilogue_op.apply(accum[0], c_elems[0]);
        accum[1] = epilogue_op.apply(accum[1], c_elems[1]);
      }
    }
  }

  /* Store results from simdgroup_matrix results into device memory,
   * fusing a binary epilogue against C on the way out */
  METAL_FUNC void store_result(
      device U* D,
      const int ldd,
      const device U* C,
      const int ldc,
      const int fdc,
      thread const Epilogue& epilogue_op) const {
    // Adjust for simdgroup and thread location
    C += (sm + tm) * ldc + (tn + sn) * fdc;
    D += (sm + tm) * ldd + tn + sn;

    // Loop over all simdgroup tiles
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < TM; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < TN; j++) {
        // Get accumulated result and associated offset in C
        thread const auto& accum = results[i * TN + j].thread_elements();
        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
        int offset_d = (i * TM_stride) * ldd + (j * TN_stride);

        // Apply epilogue
        U outs[2] = {
            epilogue_op.apply(accum[0], C[offset_c]),
            epilogue_op.apply(accum[1], C[offset_c + fdc])};

        // Write out D
        D[offset_d] = outs[0];
        D[offset_d + 1] = outs[1];
      }
    }
  }

  /* Bounds-checked variant of the fused store above */
  METAL_FUNC void store_result_safe(
      device U* D,
      const int ldd,
      const device U* C,
      const int ldc,
      const int fdc,
      short2 dst_tile_dims,
      thread const Epilogue& epilogue_op) const {
    // Adjust for simdgroup and thread location
    C += (sm + tm) * ldc + (tn + sn) * fdc;
    D += (sm + tm) * ldd + tn + sn;
    dst_tile_dims -= short2(tn + sn, sm + tm);

    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
      return;

    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < TM; i++) {
      if (i * TM_stride < dst_tile_dims.y) {
        STEEL_PRAGMA_UNROLL
        for (int j = 0; j < TN; j++) {
          // Get accumulated result and associated offset in C
          thread const auto& accum = results[i * TN + j].thread_elements();
          int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
          int offset_d = (i * TM_stride) * ldd + (j * TN_stride);

          // Apply epilogue and output C
          if (j * TN_stride < dst_tile_dims.x) {
            D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]);
          }

          if (j * TN_stride + 1 < dst_tile_dims.x) {
            D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]);
          }
        }
      }
    }
  }
};

} // namespace steel
} // namespace mlx
Definition bf16.h:265
Definition allocator.h:7
#define STEEL_PRAGMA_UNROLL
Definition defines.h:4
#define STEEL_CONST
Definition defines.h:3
Definition mma.h:35
short As_offset
Definition mma.h:72
STEEL_CONST short jump_b
Definition mma.h:54
simdgroup_matrix< AccumType, 8, 8 > Bsimd[TN]
Definition mma.h:61
STEEL_CONST short TM_stride
Definition mma.h:37
METAL_FUNC void mma(const threadgroup T *As, const threadgroup T *Bs)
Definition mma.h:93
simdgroup_matrix< AccumType, 8, 8 > results[TM *TN]
Definition mma.h:62
STEEL_CONST short TN
Definition mma.h:44
METAL_FUNC void store_result_safe(device U *D, const int ldd, const device U *C, const int ldc, const int fdc, short2 dst_tile_dims, thread const Epilogue &epilogue_op) const
Definition mma.h:320
METAL_FUNC void store_result(device U *D, const int ldd, const device U *C, const int ldc, const int fdc, thread const Epilogue &epilogue_op) const
Definition mma.h:287
METAL_FUNC void apply_epilogue(const device U *C, const int ldc, const int fdc, thread const BinaryEpilogue &epilogue_op)
Definition mma.h:221
METAL_FUNC void store_result(device U *D, const int ldd) const
Definition mma.h:147
STEEL_CONST short TN_stride
Definition mma.h:39
STEEL_CONST short tile_stride_a
Definition mma.h:56
simdgroup_matrix< AccumType, 8, 8 > Asimd[TM]
Definition mma.h:60
short Bs_offset
Definition mma.h:73
METAL_FUNC void apply_epilogue_safe(const device U *C, const int ldc, const int fdc, short2 dst_tile_dims, thread const BinaryEpilogue &epilogue_op)
Definition mma.h:247
METAL_FUNC BlockMMA(ushort simd_group_id, ushort simd_lane_id)
Definition mma.h:76
short sm
Definition mma.h:69
STEEL_CONST short simd_stride_a
Definition mma.h:47
const short tm
Definition mma.h:66
STEEL_CONST short TM
Definition mma.h:42
const short tn
Definition mma.h:67
STEEL_CONST short jump_a
Definition mma.h:53
short sn
Definition mma.h:70
STEEL_CONST short tile_stride_b
Definition mma.h:57
STEEL_CONST short simd_stride_b
Definition mma.h:49
METAL_FUNC void apply_epilogue(thread const UnaryEpilogue &epilogue_op)
Definition mma.h:203
METAL_FUNC void store_result_safe(device U *D, const int ldd, short2 dst_tile_dims) const
Definition mma.h:171