mma.h
// Copyright © 2024 Apple Inc.

#pragma once

#include <metal_simdgroup>
#include <metal_simdgroup_matrix>
#include <metal_stdlib>

#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
#include "mlx/backend/metal/kernels/steel/utils.h"

using namespace metal;

///////////////////////////////////////////////////////////////////////////////
// MMA helper
///////////////////////////////////////////////////////////////////////////////

namespace mlx {
namespace steel {

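// BlockMMA: helper for accumulating a (BM, BN) output tile of a GEMM as 8x8
// simdgroup matrix multiply-accumulates over BK-wide slices of the reduction
// dimension.
//
//   T, U             : input and output element types
//   BM, BN, BK       : threadgroup tile sizes along M, N, and K
//   WM, WN           : simdgroups per threadgroup along M and N
//   transpose_a/b    : whether the A / B threadgroup tiles are transposed
//   lda_tgp, ldb_tgp : leading dimensions of the A / B threadgroup tiles
//   AccumType        : accumulation type (float by default)
//   Epilogue         : element-wise transform applied when storing results
//
// Each simdgroup owns a TM x TN grid of 8x8 fragments, with
// TM = BM / (8 * WM) and TN = BN / (8 * WN).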
template <
    typename T,
    typename U,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    short lda_tgp,
    short ldb_tgp,
    typename AccumType = float,
    typename Epilogue = TransformNone<U, AccumType>>
struct BlockMMA {
  // Warp tile simdgroup matrix strides along M
  STEEL_CONST short TM_stride = 8 * WM;
  // Warp tile simdgroup matrix strides along N
  STEEL_CONST short TN_stride = 8 * WN;

  // Warp tile size along M
  STEEL_CONST short TM = BM / TM_stride;
  // Warp tile size along N
  STEEL_CONST short TN = BN / TN_stride;

  // Strides of A, B along reduction axis
  STEEL_CONST short simd_stride_a = {
      transpose_a ? TM_stride : TM_stride * lda_tgp};
  STEEL_CONST short simd_stride_b = {
      transpose_b ? TN_stride * ldb_tgp : TN_stride};

  // Jump between elements
  STEEL_CONST short jump_a = {transpose_a ? lda_tgp : 1};
  STEEL_CONST short jump_b = {transpose_b ? ldb_tgp : 1};

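  // tile_stride_a / tile_stride_b advance the threadgroup pointers to the
  // next 8-wide slice along the K dimension between MMA steps.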
  STEEL_CONST short tile_stride_a = {transpose_a ? 8 * lda_tgp : 8};
  STEEL_CONST short tile_stride_b = {transpose_b ? 8 : 8 * ldb_tgp};

  // Simdgroup matrices
  simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
  simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
  simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
      simdgroup_matrix<AccumType, 8, 8>(0)};

  // Offsets within threadgroup
  const short tm;
  const short tn;

  short sm;
  short sn;

  short As_offset;
  short Bs_offset;

  /* Constructor */
  METAL_FUNC BlockMMA(
      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
      ushort simd_lane_id [[thread_index_in_simdgroup]])
      : tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
    // Determine thread position in simdgroup matrix
    short qid = simd_lane_id / 4;
    sm = (qid & 4) + (simd_lane_id / 2) % 4;
    sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
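    // Each thread ends up owning two adjacent elements of every 8x8 fragment:
    // row sm, columns sn and sn + 1 of the tile.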

    // Determine thread and simdgroup offset
    As_offset =
        transpose_a ? ((sn)*lda_tgp + (tm + sm)) : ((sn) + (tm + sm) * lda_tgp);
    Bs_offset =
        transpose_b ? ((tn + sn) * ldb_tgp + (sm)) : ((sm)*ldb_tgp + (tn + sn));
  }

  /* (BM, BK) X (BK, BN) multiply accumulate function */
  METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
    // Adjust for simdgroup and thread location
    As += As_offset;
    Bs += Bs_offset;

    // Iterate over BK in blocks of 8
    STEEL_PRAGMA_UNROLL
    for (short kk = 0; kk < BK; kk += 8) {
      simdgroup_barrier(mem_flags::mem_none);

      // Load elements from threadgroup A as simdgroup matrices
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < TM; i++) {
        Asimd[i].thread_elements()[0] =
            static_cast<AccumType>(As[i * simd_stride_a + 0]);
        Asimd[i].thread_elements()[1] =
            static_cast<AccumType>(As[i * simd_stride_a + jump_a]);
      }

      simdgroup_barrier(mem_flags::mem_none);

      // Load elements from threadgroup B as simdgroup matrices
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < TN; j++) {
        Bsimd[j].thread_elements()[0] =
            static_cast<AccumType>(Bs[j * simd_stride_b + 0]);
        Bsimd[j].thread_elements()[1] =
            static_cast<AccumType>(Bs[j * simd_stride_b + jump_b]);
      }

      simdgroup_barrier(mem_flags::mem_none);

      // Multiply and accumulate into result simdgroup matrices
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < TM; i++) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < TN; j++) {
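          // On odd rows of the fragment grid, traverse N in reverse
          // (serpentine order) so back-to-back iterations at row boundaries
          // reuse the same Bsimd fragment.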
          short j_serp = (i % 2) ? (TN - 1 - j) : j;

          simdgroup_multiply_accumulate(
              results[i * TN + j_serp],
              Asimd[i],
              Bsimd[j_serp],
              results[i * TN + j_serp]);
        }
      }

      // Progress to next simdgroup tile
      As += tile_stride_a;
      Bs += tile_stride_b;
    }
  }

  /* Store results from simdgroup_matrix results into device memory */
  METAL_FUNC void store_result(device U* D, const int ldd) const {
    // Adjust for simdgroup and thread location
    D += (sm + tm) * ldd + tn + sn;

    // Loop over all simdgroup tiles
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < TM; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < TN; j++) {
        // Get accumulated result and associated offset in C
        thread const auto& accum = results[i * TN + j].thread_elements();
        int offset = (i * TM_stride) * ldd + (j * TN_stride);

        // Apply epilogue
        U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])};

        // Write out D
        D[offset] = outs[0];
        D[offset + 1] = outs[1];
      }
    }
  }

  METAL_FUNC void
  store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) const {
    // Adjust for simdgroup and thread location
    D += (sm + tm) * ldd + (tn + sn);
    dst_tile_dims -= short2(tn + sn, sm + tm);
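    // After this shift, dst_tile_dims.x is the number of in-bounds columns
    // and dst_tile_dims.y the number of in-bounds rows left for this thread.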

    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
      return;

    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < TM; i++) {
      if (i * TM_stride < dst_tile_dims.y) {
        STEEL_PRAGMA_UNROLL
        for (int j = 0; j < TN; j++) {
          // Get accumulated result and associated offset in C
          thread const auto& accum = results[i * TN + j].thread_elements();
          int offset = (i * TM_stride) * ldd + (j * TN_stride);

          // Apply epilogue and output C
          if (j * TN_stride < dst_tile_dims.x) {
            D[offset] = Epilogue::apply(accum[0]);
          }

          if (j * TN_stride + 1 < dst_tile_dims.x) {
            D[offset + 1] = Epilogue::apply(accum[1]);
          }
        }
      }
    }
  }

  /* Apply epilogue */
  template <typename BinaryEpilogue>
  METAL_FUNC void apply_epilogue(
      const device U* C,
      const int ldc,
      const int fdc,
      thread const BinaryEpilogue& epilogue_op) {
    // Adjust for simdgroup and thread location
    C += (sm + tm) * ldc + (tn + sn) * fdc;

    // Loop over all simdgroup tiles
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < TM; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < TN; j++) {
        // Get accumulated result and associated offset in C
        thread auto& accum = results[i * TN + j].thread_elements();
        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;

        // Apply epilogue
        accum[0] = epilogue_op.apply(accum[0], C[offset_c]);
        accum[1] = epilogue_op.apply(accum[1], C[offset_c + fdc]);
      }
    }
  }

  /* Apply epilogue */
  template <typename BinaryEpilogue>
  METAL_FUNC void apply_epilogue_safe(
      const device U* C,
      const int ldc,
      const int fdc,
      short2 dst_tile_dims,
      thread const BinaryEpilogue& epilogue_op) {
    // Adjust for simdgroup and thread location
    C += (sm + tm) * ldc + (tn + sn) * fdc;
    dst_tile_dims -= short2(tn + sn, sm + tm);

    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
      return;

    // Loop over all simdgroup tiles
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < TM; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < TN; j++) {
        // Get accumulated result and associated offset in C
        thread auto& accum = results[i * TN + j].thread_elements();
        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;

        // Read C
        U c_elems[2] = {0};
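        // Read both, one, or neither element of C depending on how many
        // columns remain in bounds; out-of-bounds elements stay zero.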

        if ((j * TN_stride + 1) < dst_tile_dims.x) {
          c_elems[0] = C[offset_c];
          c_elems[1] = C[offset_c + fdc];
        } else if ((j * TN_stride) < dst_tile_dims.x) {
          c_elems[0] = C[offset_c];
        }

        // Apply epilogue
        accum[0] = epilogue_op.apply(accum[0], c_elems[0]);
        accum[1] = epilogue_op.apply(accum[1], c_elems[1]);
      }
    }
  }

  /* Store results from simdgroup_matrix results into device memory */
  METAL_FUNC void store_result(
      device U* D,
      const int ldd,
      const device U* C,
      const int ldc,
      const int fdc,
      thread const Epilogue& epilogue_op) const {
    // Adjust for simdgroup and thread location
    C += (sm + tm) * ldc + (tn + sn) * fdc;
    D += (sm + tm) * ldd + tn + sn;

    // Loop over all simdgroup tiles
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < TM; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < TN; j++) {
        // Get accumulated result and associated offset in C
        thread const auto& accum = results[i * TN + j].thread_elements();
        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
        int offset_d = (i * TM_stride) * ldd + (j * TN_stride);

        // Apply epilogue
        U outs[2] = {
            epilogue_op.apply(accum[0], C[offset_c]),
            epilogue_op.apply(accum[1], C[offset_c + fdc])};

        // Write out D
        D[offset_d] = outs[0];
        D[offset_d + 1] = outs[1];
      }
    }
  }

  METAL_FUNC void store_result_safe(
      device U* D,
      const int ldd,
      const device U* C,
      const int ldc,
      const int fdc,
      short2 dst_tile_dims,
      thread const Epilogue& epilogue_op) const {
    // Adjust for simdgroup and thread location
    C += (sm + tm) * ldc + (tn + sn) * fdc;
    D += (sm + tm) * ldd + tn + sn;
    dst_tile_dims -= short2(tn + sn, sm + tm);

    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
      return;

    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < TM; i++) {
      if (i * TM_stride < dst_tile_dims.y) {
        STEEL_PRAGMA_UNROLL
        for (int j = 0; j < TN; j++) {
          // Get accumulated result and associated offset in C
          thread const auto& accum = results[i * TN + j].thread_elements();
          int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
          int offset_d = (i * TM_stride) * ldd + (j * TN_stride);

          // Apply epilogue and output C
          if (j * TN_stride < dst_tile_dims.x) {
            D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]);
          }

          if (j * TN_stride + 1 < dst_tile_dims.x) {
            D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]);
          }
        }
      }
    }
  }
};

} // namespace steel
} // namespace mlx
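
// ---------------------------------------------------------------------------
// Usage sketch: a minimal outline of how a GEMM kernel might drive BlockMMA,
// assuming the (BM, BK) and (BK, BN) tiles of A and B are staged into
// threadgroup memory each iteration. The kernel name, buffer bindings, and
// loading step below are illustrative placeholders, not the actual MLX kernel.
//
//   using namespace mlx::steel;
//
//   template <typename T, int BM, int BN, int BK, int WM, int WN>
//   [[kernel]] void gemm_sketch(
//       const device T* A [[buffer(0)]],
//       const device T* B [[buffer(1)]],
//       device T* D [[buffer(2)]],
//       const constant int& K [[buffer(3)]],
//       const constant int& ldd [[buffer(4)]],
//       ushort simd_gid [[simdgroup_index_in_threadgroup]],
//       ushort simd_lid [[thread_index_in_simdgroup]]) {
//     threadgroup T As[BM * BK]; // A tile, row-major, leading dim BK
//     threadgroup T Bs[BK * BN]; // B tile, row-major, leading dim BN
//
//     // Non-transposed tiles, so lda_tgp = BK and ldb_tgp = BN.
//     BlockMMA<T, T, BM, BN, BK, WM, WN, false, false, BK, BN> mma_op(
//         simd_gid, simd_lid);
//
//     for (int k = 0; k < K; k += BK) {
//       // ... cooperatively load the next A and B tiles into As / Bs ...
//       threadgroup_barrier(mem_flags::mem_threadgroup);
//       mma_op.mma(As, Bs);
//       threadgroup_barrier(mem_flags::mem_threadgroup);
//     }
//
//     // D is assumed to point at this threadgroup's output tile origin.
//     mma_op.store_result(D, ldd);
//   }
// ---------------------------------------------------------------------------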