From 58d0e199e15d7146bcfbb10884fb2649affe33b1 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Mon, 5 Aug 2024 15:51:13 -0700
Subject: [PATCH] add bfloat conv for winograd (#1306)

* add bfloat conv for winograd

* accumulate in fp32

* accumulate in fp32

* accumulate in bf16
---
 mlx/backend/metal/kernels/conv.metal | 37 +++++++++++++++++++------------------
 python/tests/test_conv.py            | 32 +++++++++++++++++++++-----------
 2 files changed, 40 insertions(+), 29 deletions(-)

diff --git a/mlx/backend/metal/kernels/conv.metal b/mlx/backend/metal/kernels/conv.metal
index e67acd93a..fd43aa371 100644
--- a/mlx/backend/metal/kernels/conv.metal
+++ b/mlx/backend/metal/kernels/conv.metal
@@ -344,12 +344,12 @@ winograd_conv_2d_weight_transform(
   const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
 
   // Initialize G matrix
-  simdgroup_matrix<T, 8, 8> G;
+  simdgroup_matrix<float, 8, 8> G;
   G.thread_elements()[0] = WGT::wt_transform[sm][sn];
   G.thread_elements()[1] = WGT::wt_transform[sm][sn + 1];
 
   // Initialize Gt matrix
-  simdgroup_matrix<T, 8, 8> Gt;
+  simdgroup_matrix<float, 8, 8> Gt;
   Gt.thread_elements()[0] = WGT::wt_transform[sn][sm];
   Gt.thread_elements()[1] = WGT::wt_transform[sn + 1][sm];
 
@@ -381,15 +381,15 @@ winograd_conv_2d_weight_transform(
     threadgroup_barrier(mem_flags::mem_threadgroup);
     // Do transform and store the result
     for (int c = 0; c < BC; ++c) {
-      simdgroup_matrix<T, 8, 8> g;
+      simdgroup_matrix<float, 8, 8> g;
       g.thread_elements()[0] =
           sm < R && sn < R ? Ws[simd_group_id][sm][sn][c] : T(0);
       g.thread_elements()[1] =
           sm < R && sn + 1 < R ? Ws[simd_group_id][sm][sn + 1][c] : T(0);
 
-      simdgroup_matrix<T, 8, 8> g_out = (G * g) * Gt;
-      wt_out_0[c * O] = g_out.thread_elements()[0];
-      wt_out_1[c * O] = g_out.thread_elements()[1];
+      simdgroup_matrix<float, 8, 8> g_out = (G * g) * Gt;
+      wt_out_0[c * O] = static_cast<T>(g_out.thread_elements()[0]);
+      wt_out_1[c * O] = static_cast<T>(g_out.thread_elements()[1]);
     }
 
     wt_in += BC;
@@ -433,12 +433,12 @@ winograd_conv_2d_input_transform(
   const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
 
   // Initialize B matrix
-  simdgroup_matrix<T, 8, 8> B;
+  simdgroup_matrix<float, 8, 8> B;
   B.thread_elements()[0] = WGT::in_transform[sm][sn];
   B.thread_elements()[1] = WGT::in_transform[sm][sn + 1];
 
   // Initialize Bt matrix
-  simdgroup_matrix<T, 8, 8> Bt;
+  simdgroup_matrix<float, 8, 8> Bt;
   Bt.thread_elements()[0] = WGT::in_transform[sn][sm];
   Bt.thread_elements()[1] = WGT::in_transform[sn + 1][sm];
 
@@ -493,13 +493,13 @@ winograd_conv_2d_input_transform(
     threadgroup_barrier(mem_flags::mem_threadgroup);
     // Do transform and store the result
    for (int c = simd_group_id; c < BC; c += N_SIMD_GROUPS) {
-      simdgroup_matrix<T, 8, 8> I;
+      simdgroup_matrix<float, 8, 8> I;
       I.thread_elements()[0] = Is[sm][sn][c];
       I.thread_elements()[1] = Is[sm][sn + 1][c];
 
-      simdgroup_matrix<T, 8, 8> I_out = (Bt * I) * B;
-      inp_out_0[c] = I_out.thread_elements()[0];
-      inp_out_1[c] = I_out.thread_elements()[1];
+      simdgroup_matrix<float, 8, 8> I_out = (Bt * I) * B;
+      inp_out_0[c] = static_cast<T>(I_out.thread_elements()[0]);
+      inp_out_1[c] = static_cast<T>(I_out.thread_elements()[1]);
     }
 
     inp_in += BC;
@@ -543,12 +543,12 @@ winograd_conv_2d_output_transform(
   const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
 
   // Initialize A matrix
-  simdgroup_matrix<T, 8, 8> B;
+  simdgroup_matrix<float, 8, 8> B;
   B.thread_elements()[0] = WGT::out_transform[sm][sn];
   B.thread_elements()[1] = WGT::out_transform[sm][sn + 1];
 
   // Initialize At matrix
-  simdgroup_matrix<T, 8, 8> Bt;
+  simdgroup_matrix<float, 8, 8> Bt;
   Bt.thread_elements()[0] = WGT::out_transform[sn][sm];
   Bt.thread_elements()[1] = WGT::out_transform[sn + 1][sm];
 
@@ -597,16 +597,16 @@ winograd_conv_2d_output_transform(
     threadgroup_barrier(mem_flags::mem_threadgroup);
     // Do transform and store the result
    for (int c = simd_group_id; c < BO; c += N_SIMD_GROUPS) {
-      simdgroup_matrix<T, 8, 8> O_mat;
+      simdgroup_matrix<float, 8, 8> O_mat;
       O_mat.thread_elements()[0] = out_in_0[c];
       O_mat.thread_elements()[1] = out_in_1[c];
 
-      simdgroup_matrix<T, 8, 8> O_out = (Bt * (O_mat * B));
+      simdgroup_matrix<float, 8, 8> O_out = (Bt * (O_mat * B));
       if ((sm < M) && (sn < M)) {
-        Os[sm][sn][c] = O_out.thread_elements()[0];
+        Os[sm][sn][c] = static_cast<T>(O_out.thread_elements()[0]);
       }
       if ((sm < M) && ((sn + 1) < M)) {
-        Os[sm][sn + 1][c] = O_out.thread_elements()[1];
+        Os[sm][sn + 1][c] = static_cast<T>(O_out.thread_elements()[1]);
       }
     }
 
@@ -650,4 +650,5 @@ winograd_conv_2d_output_transform(
 // clang-format off
 instantiate_winograd_conv_2d(float32, float);
+instantiate_winograd_conv_2d(bfloat16, bfloat16_t);
 instantiate_winograd_conv_2d(float16, half);
 // clang-format on

diff --git a/python/tests/test_conv.py b/python/tests/test_conv.py
index b54619671..d5b43b2a2 100644
--- a/python/tests/test_conv.py
+++ b/python/tests/test_conv.py
@@ -275,7 +275,6 @@ class TestConv(mlx_tests.MLXTestCase):
             dilation=(1, 1),
             groups=1,
             dtype="float32",
-            atol=1e-5,
         ):
             with self.subTest(
                 dtype=dtype,
@@ -289,19 +288,22 @@ class TestConv(mlx_tests.MLXTestCase):
                 dilation=dilation,
                 groups=groups,
             ):
-                np_dtype = getattr(np, dtype)
                 np.random.seed(0)
                 iH, iW = idim
                 kH, kW = kdim
                 scale = 1.0 / math.sqrt(kH * kW * C)
-                in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype)
-                wt_np = np.random.normal(0.0, 1.0, (O, kH, kW, int(C / groups))).astype(
-                    np_dtype
-                )
+                in_np = np.random.normal(0.0, scale, (N, iH, iW, C))
+                wt_np = np.random.normal(0.0, 1.0, (O, kH, kW, int(C / groups)))
 
-                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
+                mx_dtype = getattr(mx, dtype)
+                torch_dtype = getattr(torch, dtype)
+                in_mx, wt_mx = map(
+                    lambda x: mx.array(x).astype(mx_dtype), (in_np, wt_np)
+                )
                 in_pt, wt_pt = map(
-                    lambda x: torch.from_numpy(x.transpose(0, 3, 1, 2)).to("cpu"),
+                    lambda x: torch.from_numpy(x.transpose(0, 3, 1, 2))
+                    .to("cpu")
+                    .to(torch_dtype),
                     (in_np, wt_np),
                 )
 
@@ -312,7 +314,7 @@ class TestConv(mlx_tests.MLXTestCase):
                     padding=padding,
                     dilation=dilation,
                     groups=groups,
-                )
+                ).astype(mx.float32)
                 out_pt = torch.conv2d(
                     in_pt,
                     wt_pt,
@@ -321,12 +323,20 @@ class TestConv(mlx_tests.MLXTestCase):
                     dilation=dilation,
                     groups=groups,
                 )
-                out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)
+                out_pt = (
+                    torch.permute(out_pt, (0, 2, 3, 1))
+                    .to(torch.float32)
+                    .numpy(force=True)
+                )
 
                 self.assertEqual(out_pt.shape, out_mx.shape)
+                if dtype == "bfloat16":
+                    atol, rtol = 1e-1, 1e-3
+                else:
+                    atol, rtol = 1e-5, 1e-6
                 self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))
 
-        for dtype in ("float32",):
+        for dtype in ("float32", "bfloat16"):
             for N, C, O in (
                 (1, 1, 1),
                 (1, 6, 1),
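
For reference, a minimal sketch of exercising the new bfloat16 path from Python, mirroring what the updated test does: run the same convolution in bfloat16 and float32, compare in float32, and use a loose tolerance to absorb bfloat16 rounding. The shapes, padding, and tolerance below are illustrative assumptions, not values taken from the patch; the exact conditions under which the conv is dispatched to the Winograd kernels live in the conv backend.

import math

import mlx.core as mx
import numpy as np

np.random.seed(0)
# Illustrative shapes: 3x3 kernel, stride 1, square input, channels
# chosen so the Winograd path is plausibly eligible (assumption).
N, H, W, C, O, k = 1, 32, 32, 32, 32, 3
scale = 1.0 / math.sqrt(k * k * C)
in_np = np.random.normal(0.0, scale, (N, H, W, C))
wt_np = np.random.normal(0.0, 1.0, (O, k, k, C))

# Same inputs in bfloat16 and float32.
in_bf, wt_bf = (mx.array(x).astype(mx.bfloat16) for x in (in_np, wt_np))
in_f32, wt_f32 = (mx.array(x).astype(mx.float32) for x in (in_np, wt_np))

# Compare the bfloat16 result against float32 in float32 precision.
out_bf = mx.conv2d(in_bf, wt_bf, stride=1, padding=1).astype(mx.float32)
out_f32 = mx.conv2d(in_f32, wt_f32, stride=1, padding=1)
assert np.allclose(np.array(out_bf), np.array(out_f32), atol=1e-1)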