MLX
Loading...
Searching...
No Matches
mma.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
3#pragma once
4
5#include <metal_simdgroup>
6#include <metal_simdgroup_matrix>
7#include <metal_stdlib>
8
11
12using namespace metal;
13
15// MMA helper
17
18namespace mlx {
19namespace steel {
20
21template <
22 typename T,
23 typename U,
24 int BM,
25 int BN,
26 int BK,
27 int WM,
28 int WN,
29 bool transpose_a,
30 bool transpose_b,
31 short lda_tgp,
32 short ldb_tgp,
33 typename AccumType = float,
34 typename Epilogue = TransformNone<U, AccumType>>
35struct BlockMMA {
36 // Warp tile simdgroup matrix strides along M
37 STEEL_CONST short TM_stride = 8 * WM;
38 // Warp tile simdgroup matrix strides along M
39 STEEL_CONST short TN_stride = 8 * WN;
40
41 // Warp tile size along M
42 STEEL_CONST short TM = BM / TM_stride;
43 // Warp tile size along N
44 STEEL_CONST short TN = BN / TN_stride;
45
46 // Strides of A, B along reduction axis
48 transpose_a ? TM_stride : TM_stride * lda_tgp};
50 transpose_b ? TN_stride * ldb_tgp : TN_stride};
51
52 // Jump between elements
53 STEEL_CONST short jump_a = {transpose_a ? lda_tgp : 1};
54 STEEL_CONST short jump_b = {transpose_b ? ldb_tgp : 1};
55
56 STEEL_CONST short tile_stride_a = {transpose_a ? 8 * lda_tgp : 8};
57 STEEL_CONST short tile_stride_b = {transpose_b ? 8 : 8 * ldb_tgp};
58
59 // Simdgroup matrices
60 simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
61 simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
62 simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
63 simdgroup_matrix<AccumType, 8, 8>(0)};
64
65 // Offsets within threadgroup
66 const short tm;
67 const short tn;
68
69 short sm;
70 short sn;
71
72 short As_offset;
73 short Bs_offset;
74
75 /* Constructor */
76 METAL_FUNC BlockMMA(
77 ushort simd_group_id [[simdgroup_index_in_threadgroup]],
78 ushort simd_lane_id [[thread_index_in_simdgroup]])
79 : tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
80 // Determine thread position in simdgroup matrix
81 short qid = simd_lane_id / 4;
82 sm = (qid & 4) + (simd_lane_id / 2) % 4;
83 sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
84
85 // Determine thread and simdgroup offset
86 As_offset =
87 transpose_a ? ((sn)*lda_tgp + (tm + sm)) : ((sn) + (tm + sm) * lda_tgp);
88 Bs_offset =
89 transpose_b ? ((tn + sn) * ldb_tgp + (sm)) : ((sm)*ldb_tgp + (tn + sn));
90 }
91
92 /* (BM, BK) X (BK, BN) multiply accumulate function */
93 METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
94 // Adjust for simdgroup and thread location
95 As += As_offset;
96 Bs += Bs_offset;
97
98 // Iterate over BK in blocks of 8
100 for (short kk = 0; kk < BK; kk += 8) {
101 simdgroup_barrier(mem_flags::mem_none);
102
103 // Load elements from threadgroup A as simdgroup matrices
105 for (short i = 0; i < TM; i++) {
106 Asimd[i].thread_elements()[0] =
107 static_cast<AccumType>(As[i * simd_stride_a + 0]);
108 Asimd[i].thread_elements()[1] =
109 static_cast<AccumType>(As[i * simd_stride_a + jump_a]);
110 }
111
112 simdgroup_barrier(mem_flags::mem_none);
113
114 // Load elements from threadgroup B as simdgroup matrices
116 for (short j = 0; j < TN; j++) {
117 Bsimd[j].thread_elements()[0] =
118 static_cast<AccumType>(Bs[j * simd_stride_b + 0]);
119 Bsimd[j].thread_elements()[1] =
120 static_cast<AccumType>(Bs[j * simd_stride_b + jump_b]);
121 }
122
123 simdgroup_barrier(mem_flags::mem_none);
124
125 // Multiply and accumulate into result simdgroup matrices
127 for (short i = 0; i < TM; i++) {
129 for (short j = 0; j < TN; j++) {
130 short j_serp = (i % 2) ? (TN - 1 - j) : j;
131
132 simdgroup_multiply_accumulate(
133 results[i * TN + j_serp],
134 Asimd[i],
135 Bsimd[j_serp],
136 results[i * TN + j_serp]);
137 }
138 }
139
140 // Progress to next simdgroup tile
141 As += tile_stride_a;
142 Bs += tile_stride_b;
143 }
144 }
145
146 /* Store results from simdgroup_matrix results into device memory */
147 METAL_FUNC void store_result(device U* D, const int ldd) const {
148 // Adjust for simdgroup and thread location
149 D += (sm + tm) * ldd + tn + sn;
150
151 // Loop over all simdgroup tiles
153 for (short i = 0; i < TM; i++) {
155 for (short j = 0; j < TN; j++) {
156 // Get accumulated result and associated offset in C
157 thread const auto& accum = results[i * TN + j].thread_elements();
158 int offset = (i * TM_stride) * ldd + (j * TN_stride);
159
160 // Apply epilogue
161 U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])};
162
163 // Write out D
164 D[offset] = outs[0];
165 D[offset + 1] = outs[1];
166 }
167 }
168 }
169
170 METAL_FUNC void
171 store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) const {
172 // Adjust for simdgroup and thread location
173 D += (sm + tm) * ldd + (tn + sn);
174 dst_tile_dims -= short2(tn + sn, sm + tm);
175
176 if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
177 return;
178
180 for (int i = 0; i < TM; i++) {
181 if (i * TM_stride < dst_tile_dims.y) {
183 for (int j = 0; j < TN; j++) {
184 // Get accumulated result and associated offset in C
185 thread const auto& accum = results[i * TN + j].thread_elements();
186 int offset = (i * TM_stride) * ldd + (j * TN_stride);
187
188 // Apply epilogue and output C
189 if (j * TN_stride < dst_tile_dims.x) {
190 D[offset] = Epilogue::apply(accum[0]);
191 }
192
193 if (j * TN_stride + 1 < dst_tile_dims.x) {
194 D[offset + 1] = Epilogue::apply(accum[1]);
195 }
196 }
197 }
198 }
199 }
200
201 /* Store results from simdgroup_matrix results into device memory */
202 METAL_FUNC void store_result(
203 device U* D,
204 const int ldd,
205 const device U* C,
206 const int ldc,
207 const int fdc,
208 thread const Epilogue& epilogue_op) const {
209 // Adjust for simdgroup and thread location
210 C += (sm + tm) * ldc + (tn + sn) * fdc;
211 D += (sm + tm) * ldd + tn + sn;
212
213 // Loop over all simdgroup tiles
215 for (short i = 0; i < TM; i++) {
217 for (short j = 0; j < TN; j++) {
218 // Get accumulated result and associated offset in C
219 thread const auto& accum = results[i * TN + j].thread_elements();
220 int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
221 int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
222
223 // Apply epilogue
224 U outs[2] = {
225 epilogue_op.apply(accum[0], C[offset_c]),
226 epilogue_op.apply(accum[1], C[offset_c + fdc])};
227
228 // Write out D
229 D[offset_d] = outs[0];
230 D[offset_d + 1] = outs[1];
231 }
232 }
233 }
234
235 METAL_FUNC void store_result_safe(
236 device U* D,
237 const int ldd,
238 const device U* C,
239 const int ldc,
240 const int fdc,
241 short2 dst_tile_dims,
242 thread const Epilogue& epilogue_op) const {
243 // Adjust for simdgroup and thread location
244 C += (sm + tm) * ldc + (tn + sn) * fdc;
245 D += (sm + tm) * ldd + tn + sn;
246 dst_tile_dims -= short2(tn + sn, sm + tm);
247
248 if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
249 return;
250
252 for (int i = 0; i < TM; i++) {
253 if (i * TM_stride < dst_tile_dims.y) {
255 for (int j = 0; j < TN; j++) {
256 // Get accumulated result and associated offset in C
257 thread const auto& accum = results[i * TN + j].thread_elements();
258 int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
259 int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
260
261 // Apply epilogue and output C
262 if (j * TN_stride < dst_tile_dims.x) {
263 D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]);
264 }
265
266 if (j * TN_stride + 1 < dst_tile_dims.x) {
267 D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]);
268 }
269 }
270 }
271 }
272 }
273};
274
275} // namespace steel
276} // namespace mlx
#define STEEL_PRAGMA_UNROLL
Definition utils.h:8
#define STEEL_CONST
Definition utils.h:7
Definition bf16.h:265
Definition allocator.h:7
Definition mma.h:35
short As_offset
Definition mma.h:72
STEEL_CONST short jump_b
Definition mma.h:54
simdgroup_matrix< AccumType, 8, 8 > Bsimd[TN]
Definition mma.h:61
STEEL_CONST short TM_stride
Definition mma.h:37
METAL_FUNC void mma(const threadgroup T *As, const threadgroup T *Bs)
Definition mma.h:93
simdgroup_matrix< AccumType, 8, 8 > results[TM *TN]
Definition mma.h:62
STEEL_CONST short TN
Definition mma.h:44
METAL_FUNC void store_result_safe(device U *D, const int ldd, const device U *C, const int ldc, const int fdc, short2 dst_tile_dims, thread const Epilogue &epilogue_op) const
Definition mma.h:235
METAL_FUNC void store_result(device U *D, const int ldd, const device U *C, const int ldc, const int fdc, thread const Epilogue &epilogue_op) const
Definition mma.h:202
METAL_FUNC void store_result(device U *D, const int ldd) const
Definition mma.h:147
STEEL_CONST short TN_stride
Definition mma.h:39
STEEL_CONST short tile_stride_a
Definition mma.h:56
simdgroup_matrix< AccumType, 8, 8 > Asimd[TM]
Definition mma.h:60
short Bs_offset
Definition mma.h:73
METAL_FUNC BlockMMA(ushort simd_group_id, ushort simd_lane_id)
Definition mma.h:76
short sm
Definition mma.h:69
STEEL_CONST short simd_stride_a
Definition mma.h:47
const short tm
Definition mma.h:66
STEEL_CONST short TM
Definition mma.h:42
const short tn
Definition mma.h:67
STEEL_CONST short jump_a
Definition mma.h:53
short sn
Definition mma.h:70
STEEL_CONST short tile_stride_b
Definition mma.h:57
STEEL_CONST short simd_stride_b
Definition mma.h:49
METAL_FUNC void store_result_safe(device U *D, const int ldd, short2 dst_tile_dims) const
Definition mma.h:171