mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 10:48:09 +08:00
add bfloat conv for winograd (#1306)
* add bfloat conv for winograd * accumulate in fp32 * accumulate in fp32 * accumulate in bf16
This commit is contained in:
@@ -344,12 +344,12 @@ winograd_conv_2d_weight_transform(
|
||||
const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
|
||||
|
||||
// Initialize G matrix
|
||||
simdgroup_matrix<T, 8, 8> G;
|
||||
simdgroup_matrix<float, 8, 8> G;
|
||||
G.thread_elements()[0] = WGT::wt_transform[sm][sn];
|
||||
G.thread_elements()[1] = WGT::wt_transform[sm][sn + 1];
|
||||
|
||||
// Initialize Gt matrix
|
||||
simdgroup_matrix<T, 8, 8> Gt;
|
||||
simdgroup_matrix<float, 8, 8> Gt;
|
||||
Gt.thread_elements()[0] = WGT::wt_transform[sn][sm];
|
||||
Gt.thread_elements()[1] = WGT::wt_transform[sn + 1][sm];
|
||||
|
||||
@@ -381,15 +381,15 @@ winograd_conv_2d_weight_transform(
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Do transform and store the result
|
||||
for (int c = 0; c < BC; ++c) {
|
||||
simdgroup_matrix<T, 8, 8> g;
|
||||
simdgroup_matrix<float, 8, 8> g;
|
||||
g.thread_elements()[0] =
|
||||
sm < R && sn < R ? Ws[simd_group_id][sm][sn][c] : T(0);
|
||||
g.thread_elements()[1] =
|
||||
sm < R && sn + 1 < R ? Ws[simd_group_id][sm][sn + 1][c] : T(0);
|
||||
|
||||
simdgroup_matrix<T, 8, 8> g_out = (G * g) * Gt;
|
||||
wt_out_0[c * O] = g_out.thread_elements()[0];
|
||||
wt_out_1[c * O] = g_out.thread_elements()[1];
|
||||
simdgroup_matrix<float, 8, 8> g_out = (G * g) * Gt;
|
||||
wt_out_0[c * O] = static_cast<T>(g_out.thread_elements()[0]);
|
||||
wt_out_1[c * O] = static_cast<T>(g_out.thread_elements()[1]);
|
||||
}
|
||||
|
||||
wt_in += BC;
|
||||
@@ -433,12 +433,12 @@ winograd_conv_2d_input_transform(
|
||||
const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
|
||||
|
||||
// Initialize B matrix
|
||||
simdgroup_matrix<T, 8, 8> B;
|
||||
simdgroup_matrix<float, 8, 8> B;
|
||||
B.thread_elements()[0] = WGT::in_transform[sm][sn];
|
||||
B.thread_elements()[1] = WGT::in_transform[sm][sn + 1];
|
||||
|
||||
// Initialize Bt matrix
|
||||
simdgroup_matrix<T, 8, 8> Bt;
|
||||
simdgroup_matrix<float, 8, 8> Bt;
|
||||
Bt.thread_elements()[0] = WGT::in_transform[sn][sm];
|
||||
Bt.thread_elements()[1] = WGT::in_transform[sn + 1][sm];
|
||||
|
||||
@@ -493,13 +493,13 @@ winograd_conv_2d_input_transform(
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Do transform and store the result
|
||||
for (int c = simd_group_id; c < BC; c += N_SIMD_GROUPS) {
|
||||
simdgroup_matrix<T, 8, 8> I;
|
||||
simdgroup_matrix<float, 8, 8> I;
|
||||
I.thread_elements()[0] = Is[sm][sn][c];
|
||||
I.thread_elements()[1] = Is[sm][sn + 1][c];
|
||||
|
||||
simdgroup_matrix<T, 8, 8> I_out = (Bt * I) * B;
|
||||
inp_out_0[c] = I_out.thread_elements()[0];
|
||||
inp_out_1[c] = I_out.thread_elements()[1];
|
||||
simdgroup_matrix<float, 8, 8> I_out = (Bt * I) * B;
|
||||
inp_out_0[c] = static_cast<T>(I_out.thread_elements()[0]);
|
||||
inp_out_1[c] = static_cast<T>(I_out.thread_elements()[1]);
|
||||
}
|
||||
|
||||
inp_in += BC;
|
||||
@@ -543,12 +543,12 @@ winograd_conv_2d_output_transform(
|
||||
const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
|
||||
|
||||
// Initialize A matrix
|
||||
simdgroup_matrix<T, 8, 8> B;
|
||||
simdgroup_matrix<float, 8, 8> B;
|
||||
B.thread_elements()[0] = WGT::out_transform[sm][sn];
|
||||
B.thread_elements()[1] = WGT::out_transform[sm][sn + 1];
|
||||
|
||||
// Initialize At matrix
|
||||
simdgroup_matrix<T, 8, 8> Bt;
|
||||
simdgroup_matrix<float, 8, 8> Bt;
|
||||
Bt.thread_elements()[0] = WGT::out_transform[sn][sm];
|
||||
Bt.thread_elements()[1] = WGT::out_transform[sn + 1][sm];
|
||||
|
||||
@@ -597,16 +597,16 @@ winograd_conv_2d_output_transform(
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Do transform and store the result
|
||||
for (int c = simd_group_id; c < BO; c += N_SIMD_GROUPS) {
|
||||
simdgroup_matrix<T, 8, 8> O_mat;
|
||||
simdgroup_matrix<float, 8, 8> O_mat;
|
||||
O_mat.thread_elements()[0] = out_in_0[c];
|
||||
O_mat.thread_elements()[1] = out_in_1[c];
|
||||
|
||||
simdgroup_matrix<T, 8, 8> O_out = (Bt * (O_mat * B));
|
||||
simdgroup_matrix<float, 8, 8> O_out = (Bt * (O_mat * B));
|
||||
if ((sm < M) && (sn < M)) {
|
||||
Os[sm][sn][c] = O_out.thread_elements()[0];
|
||||
Os[sm][sn][c] = static_cast<T>(O_out.thread_elements()[0]);
|
||||
}
|
||||
if ((sm < M) && ((sn + 1) < M)) {
|
||||
Os[sm][sn + 1][c] = O_out.thread_elements()[1];
|
||||
Os[sm][sn + 1][c] = static_cast<T>(O_out.thread_elements()[1]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -650,4 +650,5 @@ winograd_conv_2d_output_transform(
|
||||
|
||||
// clang-format off
|
||||
instantiate_winograd_conv_2d(float32, float);
|
||||
instantiate_winograd_conv_2d(bfloat16, bfloat16_t);
|
||||
instantiate_winograd_conv_2d(float16, half); // clang-format on
|
||||
|
Reference in New Issue
Block a user