MLX
 
Loading...
Searching...
No Matches
attn.h
Go to the documentation of this file.
// Copyright © 2024 Apple Inc.
2
3#pragma once
4
11
12using namespace metal;
13
// GEMM kernel class
17
18namespace mlx {
19namespace steel {
20
// Empty tag type that carries three compile-time alignment flags (M, N, K).
// It has no members; it exists only so the flags can travel as a type, e.g.
// as a defaulted trailing function parameter.
template <bool M_aligned, bool N_aligned, bool K_aligned>
struct LoopAlignment {};
23
// Threadgroup-tiled GEMM kernel.
//
// Template parameters:
//   T, U            - element type of the A/B inputs and of the D output.
//   BM, BN, BK      - tile extents handled by one threadgroup per K step.
//   WM, WN          - multiplied together (times 32) to give tgp_size.
//   transpose_a/b   - select the tile layout used for A and B.
//   MN_aligned      - when true, run() uses unchecked loads and stores.
//   K_aligned       - when false, one extra bounds-checked K step is executed.
//   AccumType       - accumulator type forwarded to BlockMMA.
//   Epilogue        - output transform forwarded to BlockMMA.
template <
    typename T,
    typename U,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    bool MN_aligned,
    bool K_aligned,
    typename AccumType = typename AccumHelper<T>::accum_type,
    typename Epilogue = TransformNone<U, AccumType>>
struct GEMMKernel {
  // Padding, in elements, added to the leading dimension of each threadgroup
  // tile; 16 / sizeof(T) elements corresponds to 16 bytes.
  STEEL_CONST short tgp_padding_a = 16 / sizeof(T);
  STEEL_CONST short tgp_padding_b = 16 / sizeof(T);

  // Element counts of the padded A and B threadgroup tiles.
  STEEL_CONST short tgp_mem_size_a =
      transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
  STEEL_CONST short tgp_mem_size_b =
      transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
  // NOTE(review): initializer restored from context (sum of both tiles) --
  // confirm against the upstream header.
  STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;

  // Threads per threadgroup (presumably 32 threads per simdgroup -- confirm).
  STEEL_CONST short tgp_size = WM * WN * 32;

  // Loader that fills the A tile in threadgroup memory.
  using loader_a_t = BlockLoader<
      T,
      transpose_a ? BK : BM,
      transpose_a ? BM : BK,
      transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
      !transpose_a,
      tgp_size>;

  // Loader that fills the B tile in threadgroup memory.
  using loader_b_t = BlockLoader<
      T,
      transpose_b ? BN : BK,
      transpose_b ? BK : BN,
      transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
      transpose_b,
      tgp_size>;

  // Multiply-accumulate operator working on the padded A and B tiles.
  using mma_t = BlockMMA<
      T,
      U,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
      transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
      AccumType,
      Epilogue>;

  // K loop used by the MN-unaligned path of run().
  //
  // Executes gemm_k_iterations full steps. Each step is: barrier, load the A
  // and B tiles, barrier, multiply-accumulate, advance both loaders. A tile is
  // loaded with load_unsafe() when its M_aligned / N_aligned flag is set and
  // with load_safe(tile_dims) otherwise. When K_aligned_ is false, one more
  // step is executed whose K extent is lbk, always using load_safe().
  //
  //   As, Bs             - threadgroup tiles for A and B.
  //   gemm_k_iterations  - number of full BK-wide steps.
  //   loader_a, loader_b - tile loaders, advanced in place.
  //   mma_op             - accumulator, updated in place.
  //   tgp_bm, tgp_bn     - valid M / N extents of this threadgroup's tile.
  //   lbk                - K extent of the final partial step.
  //   l                  - unused tag carrying the alignment flags.
  template <bool M_aligned, bool N_aligned, bool K_aligned_>
  static METAL_FUNC void gemm_loop(
      threadgroup T* As [[threadgroup(0)]],
      threadgroup T* Bs [[threadgroup(1)]],
      const int gemm_k_iterations,
      thread loader_a_t& loader_a,
      thread loader_b_t& loader_b,
      thread mma_t& mma_op,
      thread const short& tgp_bm,
      thread const short& tgp_bn,
      thread const short& lbk,
      LoopAlignment<M_aligned, N_aligned, K_aligned_> l = {}) {
    // Appease the compiler
    (void)l;

    // Bounds handed to load_safe() for the full-K steps.
    short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);

    short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);

    for (int k = 0; k < gemm_k_iterations; k++) {
      threadgroup_barrier(mem_flags::mem_threadgroup);
      // Load elements into threadgroup
      if (M_aligned) {
        loader_a.load_unsafe();
      } else {
        loader_a.load_safe(tile_dims_A);
      }

      if (N_aligned) {
        loader_b.load_unsafe();
      } else {
        loader_b.load_safe(tile_dims_B);
      }

      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Multiply and accumulate threadgroup elements
      mma_op.mma(As, Bs);

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();
    }

    // Partial last step along K: always bounds-checked, K extent is lbk.
    if (!K_aligned_) {
      threadgroup_barrier(mem_flags::mem_threadgroup);

      short2 tile_dims_A_last =
          transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
      short2 tile_dims_B_last =
          transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);

      loader_a.load_safe(tile_dims_A_last);
      loader_b.load_safe(tile_dims_B_last);

      threadgroup_barrier(mem_flags::mem_threadgroup);

      mma_op.mma(As, Bs);
    }
  }

  // Kernel entry point: computes one BM x BN tile of D from A and B.
  //
  //   A, B            - input matrices, leading dimensions params->lda / ldb.
  //   D               - output matrix, leading dimension params->ldd.
  //   params          - problem description (M, N, K, tiles_m, tiles_n,
  //                     swizzle_log, gemm_k_iterations_aligned, lda/ldb/ldd).
  //   As, Bs          - threadgroup tiles for A and B.
  //   simd_lane_id, simd_group_id - forwarded to the loaders and to mma_t.
  //   tid             - threadgroup position; selects the output tile.
  //   lid             - unused.
  static METAL_FUNC void run(
      const device T* A [[buffer(0)]],
      const device T* B [[buffer(1)]],
      device U* D [[buffer(2)]],
      const constant GEMMParams* params [[buffer(3)]],
      threadgroup T* As [[threadgroup(0)]],
      threadgroup T* Bs [[threadgroup(1)]],
      uint simd_lane_id [[thread_index_in_simdgroup]],
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]]) {
    // Pacifying compiler
    (void)lid;

    // Swizzle the threadgroup position: the low swizzle_log bits of tid.x are
    // folded into the tile row, the remaining bits give the tile column.
    const int tid_y = ((tid.y) << params->swizzle_log) +
        ((tid.x) & ((1 << params->swizzle_log) - 1));
    const int tid_x = (tid.x) >> params->swizzle_log;

    // Threadgroups mapped outside the tile grid have nothing to do.
    if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
      return;
    }

    threadgroup_barrier(mem_flags::mem_none);

    // Find block in A, B, C
    const int c_row = tid_y * BM;
    const int c_col = tid_x * BN;
    // size_t copies so the pointer offsets below are computed in size_t.
    const size_t c_row_long = size_t(c_row);
    const size_t c_col_long = size_t(c_col);

    A += transpose_a ? c_row_long : c_row_long * params->lda;
    B += transpose_b ? c_col_long * params->ldb : c_col_long;
    D += c_row_long * params->ldd + c_col_long;

    // Prepare threadgroup loading operations
    thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
    thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);

    // Prepare threadgroup mma operation
    thread mma_t mma_op(simd_group_id, simd_lane_id);

    int gemm_k_iterations = params->gemm_k_iterations_aligned;

    // MNK aligned loop
    if (MN_aligned) {
      for (int k = 0; k < gemm_k_iterations; k++) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        // Load elements into threadgroup
        loader_a.load_unsafe();
        loader_b.load_unsafe();

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);

        // Prepare for next iteration
        loader_a.next();
        loader_b.next();
      }

      // NOTE(review): gemm_loop() issues a mem_threadgroup barrier before its
      // tail reload, whereas this one uses mem_none -- confirm it is intended.
      threadgroup_barrier(mem_flags::mem_none);

      // Loop tail
      if (!K_aligned) {
        // K elements left over after the aligned iterations.
        int lbk = params->K - params->gemm_k_iterations_aligned * BK;
        short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
        short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);

        loader_a.load_safe(tile_dims_A);
        loader_b.load_safe(tile_dims_B);

        threadgroup_barrier(mem_flags::mem_threadgroup);

        mma_op.mma(As, Bs);
      }

      // Store results to device memory
      mma_op.store_result(D, params->ldd);
      return;

    }
    // MN unaligned loop
    else { // Loop over K - unaligned case
      // Valid extent of this tile, clipped to the matrix bounds.
      short tgp_bm = min(BM, params->M - c_row);
      short tgp_bn = min(BN, params->N - c_col);
      short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;

      // NOTE(review): the gemm_loop template arguments in the four branches
      // below were restored from the branch conditions (full tile => aligned
      // flag true) -- confirm against the upstream header.
      if (tgp_bm == BM && tgp_bn == BN) {
        // Full tile in both M and N.
        gemm_loop<true, true, K_aligned>(
            As,
            Bs,
            gemm_k_iterations,
            loader_a,
            loader_b,
            mma_op,
            tgp_bm,
            tgp_bn,
            leftover_bk);

        mma_op.store_result(D, params->ldd);
        return;

      } else if (tgp_bn == BN) {
        // Partial in M, full in N.
        gemm_loop<false, true, K_aligned>(
            As,
            Bs,
            gemm_k_iterations,
            loader_a,
            loader_b,
            mma_op,
            tgp_bm,
            tgp_bn,
            leftover_bk);

        mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
        return;

      } else if (tgp_bm == BM) {
        // Full in M, partial in N.
        gemm_loop<true, false, K_aligned>(
            As,
            Bs,
            gemm_k_iterations,
            loader_a,
            loader_b,
            mma_op,
            tgp_bm,
            tgp_bn,
            leftover_bk);

        mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
        return;

      } else {
        // Partial in both M and N.
        gemm_loop<false, false, K_aligned>(
            As,
            Bs,
            gemm_k_iterations,
            loader_a,
            loader_b,
            mma_op,
            tgp_bm,
            tgp_bn,
            leftover_bk);

        mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
        return;
      }
    }
  }
};
294
295} // namespace steel
296} // namespace mlx
Definition bf16_math.h:226
METAL_FUNC bfloat16_t min(bfloat16_t x, bfloat16_t y)
Definition bf16_math.h:232
Definition attn.h:19
Definition allocator.h:7
#define STEEL_CONST
Definition defines.h:3
float accum_type
Definition transforms.h:57
Definition loader.h:25
Definition mma.h:449
Definition attn.h:38
static METAL_FUNC void run(const device T *A, const device T *B, device U *D, const constant GEMMParams *params, threadgroup T *As, threadgroup T *Bs, uint simd_lane_id, uint simd_group_id, uint3 tid, uint3 lid)
Definition attn.h:141
STEEL_CONST short tgp_mem_size_b
Definition attn.h:43
STEEL_CONST short tgp_mem_size
Definition attn.h:45
static METAL_FUNC void gemm_loop(threadgroup T *As, threadgroup T *Bs, const int gemm_k_iterations, thread loader_a_t &loader_a, thread loader_b_t &loader_b, thread mma_t &mma_op, thread const short &tgp_bm, thread const short &tgp_bn, thread const short &lbk, LoopAlignment< M_aligned, N_aligned, K_aligned_ > l={})
Definition attn.h:80
STEEL_CONST short tgp_size
Definition attn.h:47
BlockLoader< T, transpose_a ? BK :BM, transpose_a ? BM :BK, transpose_a ? BM+tgp_padding_a :BK+tgp_padding_a, !transpose_a, tgp_size > loader_a_t
Definition attn.h:49
BlockLoader< T, transpose_b ? BN :BK, transpose_b ? BK :BN, transpose_b ? BK+tgp_padding_b :BN+tgp_padding_b, transpose_b, tgp_size > loader_b_t
Definition attn.h:56
STEEL_CONST short tgp_mem_size_a
Definition attn.h:41
STEEL_CONST short tgp_padding_b
Definition attn.h:40
STEEL_CONST short tgp_padding_a
Definition attn.h:39
BlockMMA< T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, transpose_a ? BM+tgp_padding_a :BK+tgp_padding_a, transpose_b ? BK+tgp_padding_b :BN+tgp_padding_b, AccumType, Epilogue > mma_t
Definition attn.h:63
Definition params.h:12
Definition attn.h:22
Definition transforms.h:15