add bfloat conv for winograd (#1306)

* add bfloat conv for winograd

* accumulate in fp32

* accumulate in fp32

* accumulate in bf16
Awni Hannun 2024-08-05 15:51:13 -07:00 committed by GitHub
parent 10b5835501
commit 58d0e199e1
2 changed files with 40 additions and 29 deletions
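
The "accumulate in fp32" steps in the commit message above are the substance of the change: bfloat16 keeps only 7 explicit mantissa bits, so rounding every intermediate back to bf16 loses precision that a float32 accumulator preserves. Below is a minimal numpy sketch of that effect; it is not MLX code, and it emulates bfloat16 by truncating the low 16 bits of a float32 (cruder than the hardware's round-to-nearest, but it shows the same drift).

import numpy as np

def bf16(x):
    # Emulate bfloat16 by zeroing the low 16 bits of each float32 value
    # (truncation rather than round-to-nearest-even; close enough here).
    x = np.atleast_1d(np.asarray(x, dtype=np.float32)).copy()
    x.view(np.uint32)[...] &= np.uint32(0xFFFF0000)
    return x

rng = np.random.default_rng(0)
vals = bf16(rng.normal(0.0, 1.0, 4096))

# fp32 accumulator over bf16 inputs: one rounding per input, none per sum
acc_f32 = np.float32(0.0)
for v in vals:
    acc_f32 += v

# bf16 accumulator: every partial sum is rounded back to bf16
acc_b16 = np.float32(0.0)
for v in vals:
    acc_b16 = bf16(acc_b16 + v)[0]

ref = vals.sum(dtype=np.float64)
print(abs(acc_f32 - ref), abs(acc_b16 - ref))  # the bf16 error is far larger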


@@ -344,12 +344,12 @@ winograd_conv_2d_weight_transform(
   const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;

   // Initialize G matrix
-  simdgroup_matrix<T, 8, 8> G;
+  simdgroup_matrix<float, 8, 8> G;
   G.thread_elements()[0] = WGT::wt_transform[sm][sn];
   G.thread_elements()[1] = WGT::wt_transform[sm][sn + 1];

   // Initialize Gt matrix
-  simdgroup_matrix<T, 8, 8> Gt;
+  simdgroup_matrix<float, 8, 8> Gt;
   Gt.thread_elements()[0] = WGT::wt_transform[sn][sm];
   Gt.thread_elements()[1] = WGT::wt_transform[sn + 1][sm];
@@ -381,15 +381,15 @@ winograd_conv_2d_weight_transform(
     threadgroup_barrier(mem_flags::mem_threadgroup);

     // Do transform and store the result
     for (int c = 0; c < BC; ++c) {
-      simdgroup_matrix<T, 8, 8> g;
+      simdgroup_matrix<float, 8, 8> g;
       g.thread_elements()[0] =
           sm < R && sn < R ? Ws[simd_group_id][sm][sn][c] : T(0);
       g.thread_elements()[1] =
           sm < R && sn + 1 < R ? Ws[simd_group_id][sm][sn + 1][c] : T(0);

-      simdgroup_matrix<T, 8, 8> g_out = (G * g) * Gt;
-      wt_out_0[c * O] = g_out.thread_elements()[0];
-      wt_out_1[c * O] = g_out.thread_elements()[1];
+      simdgroup_matrix<float, 8, 8> g_out = (G * g) * Gt;
+      wt_out_0[c * O] = static_cast<T>(g_out.thread_elements()[0]);
+      wt_out_1[c * O] = static_cast<T>(g_out.thread_elements()[1]);
     }
     wt_in += BC;
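
The hunk above shows the full pattern, which repeats verbatim in the input and output transforms below: the simdgroup matrices are declared as float, so the (G * g) * Gt product runs entirely in fp32, and the result is rounded to T (half or bfloat16_t) exactly once, via static_cast at the store. A numpy analogue of why the single rounding matters, using the textbook 4x3 F(2x2, 3x3) weight-transform matrix as a stand-in for the kernel's 8x8 WGT::wt_transform table, with the same truncation-based bf16 emulation as in the sketch above:

import numpy as np

def bf16(x):
    # Same crude bfloat16 emulation as before: drop the low 16 bits.
    x = np.ascontiguousarray(x, dtype=np.float32).copy()
    x.view(np.uint32)[...] &= np.uint32(0xFFFF0000)
    return x

# F(2x2, 3x3) weight transform: maps a 3x3 filter tile to a 4x4 tile
G = np.array(
    [[1.0, 0.0, 0.0],
     [0.5, 0.5, 0.5],
     [0.5, -0.5, 0.5],
     [0.0, 0.0, 1.0]],
    dtype=np.float32,
)

g = bf16(np.random.default_rng(1).normal(0.0, 1.0, (3, 3)))

once = bf16(G @ g @ G.T)         # fp32 math, one rounding on store (the new kernels)
twice = bf16(bf16(G @ g) @ G.T)  # every intermediate rounded to bf16 (avoided here)

print(np.abs(once - twice).max())  # typically nonzero: the extra rounding shows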
@@ -433,12 +433,12 @@ winograd_conv_2d_input_transform(
   const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;

   // Initialize B matrix
-  simdgroup_matrix<T, 8, 8> B;
+  simdgroup_matrix<float, 8, 8> B;
   B.thread_elements()[0] = WGT::in_transform[sm][sn];
   B.thread_elements()[1] = WGT::in_transform[sm][sn + 1];

   // Initialize Bt matrix
-  simdgroup_matrix<T, 8, 8> Bt;
+  simdgroup_matrix<float, 8, 8> Bt;
   Bt.thread_elements()[0] = WGT::in_transform[sn][sm];
   Bt.thread_elements()[1] = WGT::in_transform[sn + 1][sm];
@@ -493,13 +493,13 @@ winograd_conv_2d_input_transform(
     threadgroup_barrier(mem_flags::mem_threadgroup);

     // Do transform and store the result
     for (int c = simd_group_id; c < BC; c += N_SIMD_GROUPS) {
-      simdgroup_matrix<T, 8, 8> I;
+      simdgroup_matrix<float, 8, 8> I;
       I.thread_elements()[0] = Is[sm][sn][c];
       I.thread_elements()[1] = Is[sm][sn + 1][c];

-      simdgroup_matrix<T, 8, 8> I_out = (Bt * I) * B;
-      inp_out_0[c] = I_out.thread_elements()[0];
-      inp_out_1[c] = I_out.thread_elements()[1];
+      simdgroup_matrix<float, 8, 8> I_out = (Bt * I) * B;
+      inp_out_0[c] = static_cast<T>(I_out.thread_elements()[0]);
+      inp_out_1[c] = static_cast<T>(I_out.thread_elements()[1]);
     }
     inp_in += BC;
@@ -543,12 +543,12 @@ winograd_conv_2d_output_transform(
   const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;

   // Initialize A matrix
-  simdgroup_matrix<T, 8, 8> B;
+  simdgroup_matrix<float, 8, 8> B;
   B.thread_elements()[0] = WGT::out_transform[sm][sn];
   B.thread_elements()[1] = WGT::out_transform[sm][sn + 1];

   // Initialize At matrix
-  simdgroup_matrix<T, 8, 8> Bt;
+  simdgroup_matrix<float, 8, 8> Bt;
   Bt.thread_elements()[0] = WGT::out_transform[sn][sm];
   Bt.thread_elements()[1] = WGT::out_transform[sn + 1][sm];
@@ -597,16 +597,16 @@ winograd_conv_2d_output_transform(
     threadgroup_barrier(mem_flags::mem_threadgroup);

     // Do transform and store the result
     for (int c = simd_group_id; c < BO; c += N_SIMD_GROUPS) {
-      simdgroup_matrix<T, 8, 8> O_mat;
+      simdgroup_matrix<float, 8, 8> O_mat;
       O_mat.thread_elements()[0] = out_in_0[c];
       O_mat.thread_elements()[1] = out_in_1[c];

-      simdgroup_matrix<T, 8, 8> O_out = (Bt * (O_mat * B));
+      simdgroup_matrix<float, 8, 8> O_out = (Bt * (O_mat * B));
       if ((sm < M) && (sn < M)) {
-        Os[sm][sn][c] = O_out.thread_elements()[0];
+        Os[sm][sn][c] = static_cast<T>(O_out.thread_elements()[0]);
       }
       if ((sm < M) && ((sn + 1) < M)) {
-        Os[sm][sn + 1][c] = O_out.thread_elements()[1];
+        Os[sm][sn + 1][c] = static_cast<T>(O_out.thread_elements()[1]);
       }
     }
@@ -650,4 +650,5 @@ winograd_conv_2d_output_transform(

 // clang-format off
 instantiate_winograd_conv_2d(float32, float);
+instantiate_winograd_conv_2d(bfloat16, bfloat16_t);
 instantiate_winograd_conv_2d(float16, half); // clang-format on
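
That last hunk is what actually exposes the new path: the winograd kernels are now instantiated for bfloat16 alongside float32 and float16. A quick usage sketch against the Python API; the exact conditions under which conv2d dispatches to the winograd kernels (such as 3x3 filters with unit stride and dilation) are an assumption here, they live in the convolution routing code and are not part of this diff.

import mlx.core as mx

x = mx.random.normal((4, 32, 32, 64)).astype(mx.bfloat16)  # NHWC input
w = mx.random.normal((64, 3, 3, 64)).astype(mx.bfloat16)   # (O, kH, kW, C) weights

y = mx.conv2d(x, w, stride=1, padding=1)
print(y.shape, y.dtype)  # (4, 32, 32, 64) mlx.core.bfloat16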


@@ -275,7 +275,6 @@ class TestConv(mlx_tests.MLXTestCase):
             dilation=(1, 1),
             groups=1,
             dtype="float32",
-            atol=1e-5,
         ):
             with self.subTest(
                 dtype=dtype,
@@ -289,19 +288,22 @@ class TestConv(mlx_tests.MLXTestCase):
                 dilation=dilation,
                 groups=groups,
             ):
-                np_dtype = getattr(np, dtype)
                 np.random.seed(0)

                 iH, iW = idim
                 kH, kW = kdim
                 scale = 1.0 / math.sqrt(kH * kW * C)
-                in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype)
-                wt_np = np.random.normal(0.0, 1.0, (O, kH, kW, int(C / groups))).astype(
-                    np_dtype
-                )
+                in_np = np.random.normal(0.0, scale, (N, iH, iW, C))
+                wt_np = np.random.normal(0.0, 1.0, (O, kH, kW, int(C / groups)))

-                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
+                mx_dtype = getattr(mx, dtype)
+                torch_dtype = getattr(torch, dtype)
+                in_mx, wt_mx = map(
+                    lambda x: mx.array(x).astype(mx_dtype), (in_np, wt_np)
+                )
                 in_pt, wt_pt = map(
-                    lambda x: torch.from_numpy(x.transpose(0, 3, 1, 2)).to("cpu"),
+                    lambda x: torch.from_numpy(x.transpose(0, 3, 1, 2))
+                    .to("cpu")
+                    .to(torch_dtype),
                     (in_np, wt_np),
                 )
@@ -312,7 +314,7 @@ class TestConv(mlx_tests.MLXTestCase):
                     padding=padding,
                     dilation=dilation,
                     groups=groups,
-                )
+                ).astype(mx.float32)
                 out_pt = torch.conv2d(
                     in_pt,
                     wt_pt,
@@ -321,12 +323,20 @@ class TestConv(mlx_tests.MLXTestCase):
                     dilation=dilation,
                     groups=groups,
                 )
-                out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)
+                out_pt = (
+                    torch.permute(out_pt, (0, 2, 3, 1))
+                    .to(torch.float32)
+                    .numpy(force=True)
+                )

                 self.assertEqual(out_pt.shape, out_mx.shape)
+                if dtype == "bfloat16":
+                    atol, rtol = 1e-1, 1e-3
+                else:
+                    atol, rtol = 1e-5, 1e-6
                 self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))

-        for dtype in ("float32",):
+        for dtype in ("float32", "bfloat16"):
             for N, C, O in (
                 (1, 1, 1),
                 (1, 6, 1),
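
A note on the tolerance split in the last hunk: bfloat16 stores 7 explicit mantissa bits, so representable values in [1, 2) are spaced 2^-7, roughly 7.8e-3, apart, and the test inputs are themselves quantized to bf16 before either library runs. The old atol=1e-5 is below a single bf16 rounding step, so any one-ulp disagreement near 1.0 would fail it, while 1e-1 leaves roughly an order of magnitude of headroom for the per-element store rounding plus whatever the transforms contribute. (rtol is set alongside atol here but only atol is passed to np.allclose.) The back-of-envelope arithmetic, assuming nothing beyond the published bfloat16 layout:

eps_bf16 = 2.0 ** -7    # bfloat16 value spacing in [1, 2): 7 explicit mantissa bits
eps_f32 = 2.0 ** -23    # float32 value spacing in [1, 2)

print(eps_bf16)          # 0.0078125: already larger than the old atol=1e-5
print(1e-1 / eps_bf16)   # ~12.8x headroom under the new bfloat16 atol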