MLX
 
Loading...
Searching...
No Matches
steel_gemm_splitk.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
3using namespace mlx::steel;
4
6// GEMM kernels
8
// Split-K GEMM kernel (Metal). Each threadgroup handles one BM x BN tile of
// the output for one K partition and stores it into that partition's slice
// of C.
//
// Indexing, as computed in the body:
//   tid.x / tid.y -> tile column / tile row (c_col = tid.x * BN,
//                    c_row = tid.y * BM)
//   tid.z         -> K partition (k_start = split_k_partition_size * tid.z)
//                    and C slice (offset split_k_partition_stride * tid.z)
//
// Parameters:
//   A, B    input matrices; their pointers are advanced to this tile's
//           block according to transpose_a / transpose_b and lda / ldb.
//   C       output buffer; advanced to this partition's slice and tile.
//   params  GEMMSpiltKParams with sizes (M, N, K), leading dimensions,
//           tile counts and split-K partition settings.
//   simd_lane_id, simd_group_id  forwarded to the loaders and the mma op.
//   lid     unused (explicitly discarded).
//
// The kernel is declared with at most WM * WN * 32 threads per threadgroup.
//
// NOTE(review): this rendered listing skips source lines 99, 111, 123, 135
// and 154, so the last argument and the closing `);` of every
// gemm_kernel::gemm_loop call below are not visible here. Consult the
// original header before editing those calls.
9template <
10 typename T,
11 typename U,
12 int BM,
13 int BN,
14 int BK,
15 int WM,
16 int WN,
17 bool transpose_a,
18 bool transpose_b,
19 bool MN_aligned,
20 bool K_aligned>
21[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm_splitk(
22 const device T* A [[buffer(0)]],
23 const device T* B [[buffer(1)]],
24 device U* C [[buffer(2)]],
25 const constant GEMMSpiltKParams* params [[buffer(3)]],
26 uint simd_lane_id [[thread_index_in_simdgroup]],
27 uint simd_group_id [[simdgroup_index_in_threadgroup]],
28 uint3 tid [[threadgroup_position_in_grid]],
29 uint3 lid [[thread_position_in_threadgroup]]) {
30 (void)lid;
31
32 using gemm_kernel = GEMMKernel<
33 T,
34 U,
35 BM,
36 BN,
37 BK,
38 WM,
39 WN,
40 transpose_a,
41 transpose_b,
42 MN_aligned,
43 K_aligned>;
44 using loader_a_t = typename gemm_kernel::loader_a_t;
45 using loader_b_t = typename gemm_kernel::loader_b_t;
46 using mma_t = typename gemm_kernel::mma_t;
47
// Threadgroup-memory buffers handed to the A and B loaders and to gemm_loop.
48 threadgroup T As[gemm_kernel::tgp_mem_size_a];
49 threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
50
51 const int tid_x = tid.x;
52 const int tid_y = tid.y;
53 const int tid_z = tid.z;
54
// Threadgroups outside the tiles_n x tiles_m tile grid have nothing to do.
55 if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
56 return;
57 }
58
59 // Find block in A, B, C
60 const int c_row = tid_y * BM;
61 const int c_col = tid_x * BN;
62 const int k_start = params->split_k_partition_size * tid_z;
63
// Widened copies so the pointer offsets below are computed in size_t.
64 const size_t c_row_long = size_t(c_row);
65 const size_t c_col_long = size_t(c_col);
66 const size_t k_start_long = size_t(k_start);
67
// The transpose flags decide which index is scaled by the leading dimension.
68 A += transpose_a ? (c_row_long + k_start_long * params->lda)
69 : (k_start_long + c_row_long * params->lda);
70 B += transpose_b ? (k_start_long + c_col_long * params->ldb)
71 : (c_col_long + k_start_long * params->ldb);
// Each K partition writes to its own slice of C, split_k_partition_stride
// elements apart; within the slice the tile starts at (c_row, c_col).
72 C += (size_t(params->split_k_partition_stride) * tid_z) +
73 (c_row_long * params->ldc + c_col_long);
74
75 // Prepare threadgroup loading operations
76 thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
77 thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
78
79 // Prepare threadgroup mma operation
80 thread mma_t mma_op(simd_group_id, simd_lane_id);
81
82 int gemm_k_iterations = params->gemm_k_iterations_aligned;
83
// Extent of this tile that lies inside the M x N output; smaller than BM / BN
// only when the tile overhangs the matrix edge.
84 short tgp_bm = min(BM, params->M - c_row);
85 short tgp_bn = min(BN, params->N - c_col);
// Remainder of K that does not fill a whole BK block.
86 short leftover_bk = params->K % BK;
87
// Main loop, dispatched on whether the tile is full in M and/or N. All four
// calls pass the same visible arguments; they presumably differ only in the
// trailing argument that this listing omits (see the note in the header).
// Case 1: tile full in both dimensions, or kernel specialized as MN_aligned.
88 if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
89 gemm_kernel::gemm_loop(
90 As,
91 Bs,
92 gemm_k_iterations,
93 loader_a,
94 loader_b,
95 mma_op,
96 tgp_bm,
97 tgp_bn,
98 leftover_bk,
// Case 2: full in N, partial in M.
100 } else if (tgp_bn == BN) {
101 gemm_kernel::gemm_loop(
102 As,
103 Bs,
104 gemm_k_iterations,
105 loader_a,
106 loader_b,
107 mma_op,
108 tgp_bm,
109 tgp_bn,
110 leftover_bk,
// Case 3: full in M, partial in N.
112 } else if (tgp_bm == BM) {
113 gemm_kernel::gemm_loop(
114 As,
115 Bs,
116 gemm_k_iterations,
117 loader_a,
118 loader_b,
119 mma_op,
120 tgp_bm,
121 tgp_bn,
122 leftover_bk,
// Case 4: partial in both M and N.
124 } else {
125 gemm_kernel::gemm_loop(
126 As,
127 Bs,
128 gemm_k_iterations,
129 loader_a,
130 loader_b,
131 mma_op,
132 tgp_bm,
133 tgp_bn,
134 leftover_bk,
136 }
137
// Threadgroup-memory barrier between the main loop and the optional tail loop.
138 threadgroup_barrier(mem_flags::mem_threadgroup);
139
// Only the last K partition runs a second loop, covering the whole BK blocks
// of K that lie beyond its own split_k_partition_size range.
140 if ((tid_z + 1) == (params->split_k_partitions)) {
141 int gemm_k_iter_remaining =
142 (params->K - (k_start + params->split_k_partition_size)) / BK;
// When K_aligned is false the tail loop runs even with zero whole blocks left.
// The unbraced `if` guards only the gemm_loop call that follows.
143 if (!K_aligned || gemm_k_iter_remaining > 0)
144 gemm_kernel::gemm_loop(
145 As,
146 Bs,
147 gemm_k_iter_remaining,
148 loader_a,
149 loader_b,
150 mma_op,
151 tgp_bm,
152 tgp_bn,
153 leftover_bk,
155 }
156
// Full tiles go through store_result; edge tiles pass their valid
// (columns, rows) extent to store_result_safe.
157 if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
158 mma_op.store_result(C, params->ldc);
159 } else {
160 mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
161 }
162}
163
165// Split k accumulation kernel
167
// Split-K reduction kernel: one thread per output element. Each thread sums
// k_partitions partial results from C_split and writes the transformed sum
// to D.
//
// Parameters:
//   C_split           partial results; consecutive partitions of the same
//                     element are partition_stride elements apart.
//   D                 output buffer.
//   k_partitions      number of partial results to sum per element.
//   partition_stride  element distance between partitions in C_split.
//   ldd               row stride, used to index both D and C_split.
//   gid               thread position; gid.x is the column, gid.y the row.
//
// The sum is accumulated in AccT and passed through Epilogue::apply before
// the store (TransformNone by default).
168template <
169 typename AccT,
170 typename OutT,
171 typename Epilogue = TransformNone<OutT, AccT>>
172[[kernel]] void gemm_splitk_accum(
173 const device AccT* C_split [[buffer(0)]],
174 device OutT* D [[buffer(1)]],
175 const constant int& k_partitions [[buffer(2)]],
176 const constant int& partition_stride [[buffer(3)]],
177 const constant int& ldd [[buffer(4)]],
178 uint2 gid [[thread_position_in_grid]]) {
179 // Adjust D and C_split to this thread's element (same ldd row stride for both)
180 D += gid.x + gid.y * size_t(ldd);
181 C_split += gid.x + gid.y * size_t(ldd);
182
183 size_t offset = 0;
184 AccT out = 0;
185
// Sum this element across all K partitions, in partition order.
186 for (int i = 0; i < k_partitions; i++) {
187 out += C_split[offset];
188 offset += partition_stride;
189 }
190
191 // Write output
192 D[0] = Epilogue::apply(out);
193}
194
// Split-K reduction kernel with an epilogue that also reads C: one thread per
// output element. Each thread sums k_partitions partial results from C_split,
// combines the sum with the matching element of C through Epilogue, and
// writes the result to D.
//
// Parameters:
//   C_split           partial results; consecutive partitions of the same
//                     element are partition_stride elements apart.
//   D                 output buffer.
//   k_partitions      number of partial results to sum per element.
//   partition_stride  element distance between partitions in C_split.
//   ldd               row stride, used to index both D and C_split.
//   C                 second epilogue operand, indexed with its own strides.
//   ldc, fdc          strides of C along gid.y and gid.x respectively.
//   alpha, beta       scalars passed to the Epilogue constructor.
//   gid               thread position; gid.x is the column, gid.y the row.
195template <
196 typename AccT,
197 typename OutT,
198 typename Epilogue = TransformAxpby<OutT, AccT>>
// NOTE(review): the kernel declarator (listing line 199) is missing from this
// rendering; the symbol index at the end of the page names this function
// gemm_splitk_accum_axpby.
200 const device AccT* C_split [[buffer(0)]],
201 device OutT* D [[buffer(1)]],
202 const constant int& k_partitions [[buffer(2)]],
203 const constant int& partition_stride [[buffer(3)]],
204 const constant int& ldd [[buffer(4)]],
205 const device OutT* C [[buffer(5)]],
206 const constant int& ldc [[buffer(6)]],
207 const constant int& fdc [[buffer(7)]],
208 const constant float& alpha [[buffer(8)]],
209 const constant float& beta [[buffer(9)]],
210 uint2 gid [[thread_position_in_grid]]) {
211 // Adjust C, D and C_split to this thread's element
// C has a separate stride per axis (fdc for x, ldc for y); D and C_split
// use unit stride in x and ldd in y.
212 C += gid.x * size_t(fdc) + gid.y * size_t(ldc);
213 D += gid.x + gid.y * size_t(ldd);
214 C_split += gid.x + gid.y * size_t(ldd);
215
216 size_t offset = 0;
217 AccT out = 0;
218
// Sum this element across all K partitions, in partition order.
219 for (int i = 0; i < k_partitions; i++) {
220 out += C_split[offset];
221 offset += partition_stride;
222 }
223
224 // Write output
// Presumably alpha * out + beta * C with the default TransformAxpby; confirm
// against its definition in transforms.h.
225 Epilogue op(alpha, beta);
226 D[0] = op.apply(out, *C);
227}
Definition attn.h:19
void gemm_splitk(const device T *A, const device T *B, device U *C, const constant GEMMSpiltKParams *params, uint simd_lane_id, uint simd_group_id, uint3 tid, uint3 lid)
Definition steel_gemm_splitk.h:21
void gemm_splitk_accum(const device AccT *C_split, device OutT *D, const constant int &k_partitions, const constant int &partition_stride, const constant int &ldd, uint2 gid)
Definition steel_gemm_splitk.h:172
void gemm_splitk_accum_axpby(const device AccT *C_split, device OutT *D, const constant int &k_partitions, const constant int &partition_stride, const constant int &ldd, const device OutT *C, const constant int &ldc, const constant int &fdc, const constant float &alpha, const constant float &beta, uint2 gid)
Definition steel_gemm_splitk.h:199
Definition attn.h:38
Definition params.h:34
Definition attn.h:22
Definition transforms.h:39
Definition transforms.h:15