MLX
steel_conv_general.h
// Copyright © 2024 Apple Inc.

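// Implicit GEMM kernel for general 2D convolution. Each threadgroup computes
// one BM x BN tile of the implicit output matrix with WM x WN simdgroups;
// tid.z walks the spatial output phases described by jump_params, base_h,
// and base_w.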
template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    typename AccumType = float,
    typename Epilogue = TransformNone<T, AccumType>>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void
implicit_gemm_conv_2d_general(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    device T* C [[buffer(2)]],
    const constant MLXConvParams<2>* params [[buffer(3)]],
    const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
    const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]],
    const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]],
    const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  (void)lid;

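  // A is consumed as stored (BM x BK tiles) and B transposed (BN x BK);
  // each tile row in threadgroup memory gets 16 bytes of padding
  // (16 / sizeof(T) elements).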
  constexpr bool transpose_a = false;
  constexpr bool transpose_b = true;
  constexpr short tgp_padding_a = 16 / sizeof(T);
  constexpr short tgp_padding_b = 16 / sizeof(T);

  constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a;
  constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b;
  constexpr short shape_a_rows = (transpose_a ? BK : BM);
  constexpr short shape_b_rows = (transpose_b ? BN : BK);
  constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows;
  constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows;

  constexpr short tgp_size = WM * WN * 32;
  // Input loader
  using loader_a_t =
      Conv2DInputBlockLoaderGeneral<T, BM, BN, BK, tgp_size, tgp_padding_a>;

  // Weight loader
  using loader_b_t =
      Conv2DWeightBlockLoaderGeneral<T, BM, BN, BK, tgp_size, tgp_padding_b>;

  using mma_t = BlockMMA<
      T,
      T,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      shape_a_cols,
      shape_b_cols>;

  threadgroup T As[tgp_mem_size_a];
  threadgroup T Bs[tgp_mem_size_b];

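  // Undo the tile swizzle baked into the launch grid to recover this
  // threadgroup's tile row (tid_y) and column (tid_x).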
  const int tid_y = ((tid.y) << gemm_params->swizzle_log) +
      ((tid.x) & ((1 << gemm_params->swizzle_log) - 1));
  const int tid_x = (tid.x) >> gemm_params->swizzle_log;

  if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) {
    return;
  }

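  // tid.z selects one spatial phase (base_oh, base_ow) of the output;
  // base_h / base_w give, for that phase, the first contributing filter tap
  // (weight_base) and the count of contributing taps (weight_size).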
  const int tid_z = tid.z;

  const int base_oh = tid_z / jump_params->f_out_jump_w;
  const int base_ow = tid_z % jump_params->f_out_jump_w;

  const int base_wh = base_h[base_oh].weight_base;
  const int base_ww = base_w[base_ow].weight_base;

  const int base_wh_size = base_h[base_oh].weight_size;
  const int base_ww_size = base_w[base_ow].weight_size;

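  // Origin of this threadgroup's tile in the implicit GEMM output; advance
  // B to the first weight row for the tile's output channels.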
  const int c_row = tid_y * BM;
  const int c_col = tid_x * BN;
  const int K = gemm_params->K;

  B += c_col * K;

  const int4 offsets_a(0, c_row, base_oh, base_ow);
  const int2 offsets_b(0, c_col);

  // Prepare threadgroup loading operations
  loader_a_t loader_a(
      A,
      As,
      offsets_a,
      params,
      jump_params,
      base_wh,
      base_ww,
      simd_gid,
      simd_lid);
  loader_b_t loader_b(
      B,
      Bs,
      offsets_b,
      params,
      jump_params,
      base_wh,
      base_ww,
      simd_gid,
      simd_lid);

  // Prepare threadgroup mma operation
  mma_t mma_op(simd_gid, simd_lid);

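  // Each filter tap covered by this phase contributes
  // gemm_params->gemm_k_iterations K-blocks, so the main loop runs over
  // taps x K-blocks.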
  int gemm_k_iterations =
      base_wh_size * base_ww_size * gemm_params->gemm_k_iterations;

  for (int k = 0; k < gemm_k_iterations; k++) {
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // Load elements into threadgroup
    loader_a.load_unsafe();
    loader_b.load_unsafe();

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Multiply and accumulate threadgroup elements
    mma_op.mma(As, Bs);

    // Prepare for next iteration
    loader_a.next();
    loader_b.next();
  }

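  // mem_none is an execution-only barrier: the accumulators live in
  // registers, so no threadgroup-memory fence is needed before writeback.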
  threadgroup_barrier(mem_flags::mem_none);

  // Store results to device memory
  {
    // Adjust for simdgroup and thread location
    int offset_m = c_row + mma_op.sm + mma_op.tm;
    int offset_n = c_col + mma_op.sn + mma_op.tn;
    C += offset_n;

    if (offset_n >= gemm_params->N)
      return;

    short diff = gemm_params->N - offset_n;

    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < mma_t::TM; i++) {
      int cm = offset_m + i * mma_t::TM_stride;

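      // Map the implicit GEMM row cm back to (batch n, output row oh,
      // output column ow), re-applying this phase's base offsets.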
      int n = cm / jump_params->adj_out_hw;
      int hw = cm % jump_params->adj_out_hw;
      int oh =
          (hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + base_oh;
      int ow =
          (hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow;

      if (n < params->N && oh < params->oS[0] && ow < params->oS[1]) {
        int offset_cm = n * params->out_strides[0] +
            oh * params->out_strides[1] + ow * params->out_strides[2];

        STEEL_PRAGMA_UNROLL
        for (int j = 0; j < mma_t::TN; j++) {
          // Get accumulated result and associated offset in C
          thread const auto& accum =
              mma_op.results[i * mma_t::TN + j].thread_elements();
          int offset = offset_cm + (j * mma_t::TN_stride);

          // Apply epilogue and output C
          if (j * mma_t::TN_stride < diff) {
            C[offset] = Epilogue::apply(accum[0]);
          }

          if (j * mma_t::TN_stride + 1 < diff) {
            C[offset + 1] = Epilogue::apply(accum[1]);
          }
        }
      }
    }
  }
}
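This header only defines the kernel template; concrete variants are
explicitly instantiated in a .metal translation unit (MLX wraps the pattern
in a helper macro along the lines of instantiate_kernel). A minimal sketch of
such an instantiation, where the host-name string and tile sizes are
illustrative placeholders rather than the values MLX actually registers:

// Hypothetical explicit instantiation: float data and accumulation, 32x32
// output tiles, BK = 16, and a 2x2 grid of simdgroups (128 threads).
template [[host_name("implicit_gemm_conv_2d_general_float_32_32_16_2_2")]]
[[kernel]] decltype(implicit_gemm_conv_2d_general<float, 32, 32, 16, 2, 2>)
    implicit_gemm_conv_2d_general<float, 32, 32, 16, 2, 2>;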