add bfloat conv for winograd (#1306)

* add bfloat conv for winograd
* accumulate in fp32
* accumulate in bf16

Parent: 10b5835501
Commit: 58d0e199e1
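The diff below keeps the kernels' storage type T for loads and stores (now including bfloat16) but switches every simdgroup_matrix used in the Winograd transforms to float, so the transform arithmetic accumulates in fp32 and only the final store narrows back to T. A minimal Python sketch of the same tile-level policy, assuming G, B, and A are the standard F(2x2, 3x3) Winograd transform matrices (they are not part of this diff):

    import mlx.core as mx

    def winograd_tile(d: mx.array, g: mx.array, G: mx.array, B: mx.array, A: mx.array) -> mx.array:
        # d: 4x4 input tile, g: 3x3 filter, both in the storage dtype (e.g. bfloat16).
        # Every intermediate product is computed in float32, mirroring the
        # simdgroup_matrix<float, 8, 8> accumulators in the Metal kernels.
        f32 = mx.float32
        U = (G.astype(f32) @ g.astype(f32)) @ G.astype(f32).T  # weight transform
        V = (B.astype(f32).T @ d.astype(f32)) @ B.astype(f32)  # input transform
        Y = (A.astype(f32).T @ (U * V)) @ A.astype(f32)        # output transform
        return Y.astype(d.dtype)  # narrow back to the storage dtype, like static_cast<T>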
@@ -344,12 +344,12 @@ winograd_conv_2d_weight_transform(
   const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
 
   // Initialize G matrix
-  simdgroup_matrix<T, 8, 8> G;
+  simdgroup_matrix<float, 8, 8> G;
   G.thread_elements()[0] = WGT::wt_transform[sm][sn];
   G.thread_elements()[1] = WGT::wt_transform[sm][sn + 1];
 
   // Initialize Gt matrix
-  simdgroup_matrix<T, 8, 8> Gt;
+  simdgroup_matrix<float, 8, 8> Gt;
   Gt.thread_elements()[0] = WGT::wt_transform[sn][sm];
   Gt.thread_elements()[1] = WGT::wt_transform[sn + 1][sm];
 
@@ -381,15 +381,15 @@ winograd_conv_2d_weight_transform(
   threadgroup_barrier(mem_flags::mem_threadgroup);
   // Do transform and store the result
   for (int c = 0; c < BC; ++c) {
-    simdgroup_matrix<T, 8, 8> g;
+    simdgroup_matrix<float, 8, 8> g;
     g.thread_elements()[0] =
         sm < R && sn < R ? Ws[simd_group_id][sm][sn][c] : T(0);
     g.thread_elements()[1] =
         sm < R && sn + 1 < R ? Ws[simd_group_id][sm][sn + 1][c] : T(0);
 
-    simdgroup_matrix<T, 8, 8> g_out = (G * g) * Gt;
-    wt_out_0[c * O] = g_out.thread_elements()[0];
-    wt_out_1[c * O] = g_out.thread_elements()[1];
+    simdgroup_matrix<float, 8, 8> g_out = (G * g) * Gt;
+    wt_out_0[c * O] = static_cast<T>(g_out.thread_elements()[0]);
+    wt_out_1[c * O] = static_cast<T>(g_out.thread_elements()[1]);
   }
 
   wt_in += BC;
@@ -433,12 +433,12 @@ winograd_conv_2d_input_transform(
   const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
 
   // Initialize B matrix
-  simdgroup_matrix<T, 8, 8> B;
+  simdgroup_matrix<float, 8, 8> B;
   B.thread_elements()[0] = WGT::in_transform[sm][sn];
   B.thread_elements()[1] = WGT::in_transform[sm][sn + 1];
 
   // Initialize Bt matrix
-  simdgroup_matrix<T, 8, 8> Bt;
+  simdgroup_matrix<float, 8, 8> Bt;
   Bt.thread_elements()[0] = WGT::in_transform[sn][sm];
   Bt.thread_elements()[1] = WGT::in_transform[sn + 1][sm];
 
@@ -493,13 +493,13 @@ winograd_conv_2d_input_transform(
   threadgroup_barrier(mem_flags::mem_threadgroup);
   // Do transform and store the result
   for (int c = simd_group_id; c < BC; c += N_SIMD_GROUPS) {
-    simdgroup_matrix<T, 8, 8> I;
+    simdgroup_matrix<float, 8, 8> I;
     I.thread_elements()[0] = Is[sm][sn][c];
     I.thread_elements()[1] = Is[sm][sn + 1][c];
 
-    simdgroup_matrix<T, 8, 8> I_out = (Bt * I) * B;
-    inp_out_0[c] = I_out.thread_elements()[0];
-    inp_out_1[c] = I_out.thread_elements()[1];
+    simdgroup_matrix<float, 8, 8> I_out = (Bt * I) * B;
+    inp_out_0[c] = static_cast<T>(I_out.thread_elements()[0]);
+    inp_out_1[c] = static_cast<T>(I_out.thread_elements()[1]);
   }
 
   inp_in += BC;
@@ -543,12 +543,12 @@ winograd_conv_2d_output_transform(
   const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
 
   // Initialize A matrix
-  simdgroup_matrix<T, 8, 8> B;
+  simdgroup_matrix<float, 8, 8> B;
   B.thread_elements()[0] = WGT::out_transform[sm][sn];
   B.thread_elements()[1] = WGT::out_transform[sm][sn + 1];
 
   // Initialize At matrix
-  simdgroup_matrix<T, 8, 8> Bt;
+  simdgroup_matrix<float, 8, 8> Bt;
   Bt.thread_elements()[0] = WGT::out_transform[sn][sm];
   Bt.thread_elements()[1] = WGT::out_transform[sn + 1][sm];
 
@@ -597,16 +597,16 @@ winograd_conv_2d_output_transform(
   threadgroup_barrier(mem_flags::mem_threadgroup);
   // Do transform and store the result
   for (int c = simd_group_id; c < BO; c += N_SIMD_GROUPS) {
-    simdgroup_matrix<T, 8, 8> O_mat;
+    simdgroup_matrix<float, 8, 8> O_mat;
     O_mat.thread_elements()[0] = out_in_0[c];
     O_mat.thread_elements()[1] = out_in_1[c];
 
-    simdgroup_matrix<T, 8, 8> O_out = (Bt * (O_mat * B));
+    simdgroup_matrix<float, 8, 8> O_out = (Bt * (O_mat * B));
     if ((sm < M) && (sn < M)) {
-      Os[sm][sn][c] = O_out.thread_elements()[0];
+      Os[sm][sn][c] = static_cast<T>(O_out.thread_elements()[0]);
     }
     if ((sm < M) && ((sn + 1) < M)) {
-      Os[sm][sn + 1][c] = O_out.thread_elements()[1];
+      Os[sm][sn + 1][c] = static_cast<T>(O_out.thread_elements()[1]);
     }
   }
 
@@ -650,4 +650,5 @@ winograd_conv_2d_output_transform(
 
 // clang-format off
 instantiate_winograd_conv_2d(float32, float);
+instantiate_winograd_conv_2d(bfloat16, bfloat16_t);
 instantiate_winograd_conv_2d(float16, half); // clang-format on
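With the extra instantiation above, the Winograd 2D convolution kernels are built for bfloat16 alongside float32 and float16, so a bfloat16 convolution can stay in bfloat16 end to end. A quick usage sketch; whether the Winograd path is actually selected still depends on MLX's internal dispatch heuristics (e.g. 3x3 filters with unit stride):

    import mlx.core as mx

    x = mx.random.normal((1, 32, 32, 32)).astype(mx.bfloat16)  # NHWC input
    w = mx.random.normal((32, 3, 3, 32)).astype(mx.bfloat16)   # (O, kH, kW, C) weights
    y = mx.conv2d(x, w, stride=1, padding=1)
    print(y.dtype, y.shape)  # bfloat16, (1, 32, 32, 32)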
@@ -275,7 +275,6 @@ class TestConv(mlx_tests.MLXTestCase):
             dilation=(1, 1),
             groups=1,
             dtype="float32",
-            atol=1e-5,
         ):
             with self.subTest(
                 dtype=dtype,
@@ -289,19 +288,22 @@ class TestConv(mlx_tests.MLXTestCase):
                 dilation=dilation,
                 groups=groups,
             ):
-                np_dtype = getattr(np, dtype)
                 np.random.seed(0)
                 iH, iW = idim
                 kH, kW = kdim
                 scale = 1.0 / math.sqrt(kH * kW * C)
-                in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype)
-                wt_np = np.random.normal(0.0, 1.0, (O, kH, kW, int(C / groups))).astype(
-                    np_dtype
-                )
+                in_np = np.random.normal(0.0, scale, (N, iH, iW, C))
+                wt_np = np.random.normal(0.0, 1.0, (O, kH, kW, int(C / groups)))
 
-                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
+                mx_dtype = getattr(mx, dtype)
+                torch_dtype = getattr(torch, dtype)
+                in_mx, wt_mx = map(
+                    lambda x: mx.array(x).astype(mx_dtype), (in_np, wt_np)
+                )
                 in_pt, wt_pt = map(
-                    lambda x: torch.from_numpy(x.transpose(0, 3, 1, 2)).to("cpu"),
+                    lambda x: torch.from_numpy(x.transpose(0, 3, 1, 2))
+                    .to("cpu")
+                    .to(torch_dtype),
                     (in_np, wt_np),
                 )
 
@@ -312,7 +314,7 @@ class TestConv(mlx_tests.MLXTestCase):
                     padding=padding,
                     dilation=dilation,
                     groups=groups,
-                )
+                ).astype(mx.float32)
                 out_pt = torch.conv2d(
                     in_pt,
                     wt_pt,
@@ -321,12 +323,20 @@ class TestConv(mlx_tests.MLXTestCase):
                     dilation=dilation,
                     groups=groups,
                 )
-                out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)
+                out_pt = (
+                    torch.permute(out_pt, (0, 2, 3, 1))
+                    .to(torch.float32)
+                    .numpy(force=True)
+                )
 
                 self.assertEqual(out_pt.shape, out_mx.shape)
+                if dtype == "bfloat16":
+                    atol, rtol = 1e-1, 1e-3
+                else:
+                    atol, rtol = 1e-5, 1e-6
                 self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))
 
-        for dtype in ("float32",):
+        for dtype in ("float32", "bfloat16"):
             for N, C, O in (
                 (1, 1, 1),
                 (1, 6, 1),
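The test now exercises bfloat16 as well: inputs are drawn in numpy, cast to the target dtype on both the MLX and torch sides, results are brought back to float32 for comparison, and the tolerance is loosened for bfloat16. A small helper sketching that comparison pattern (the helper name is illustrative, and unlike the test it also passes rtol to the check):

    import numpy as np
    import mlx.core as mx

    def assert_matches_reference(out_mx: mx.array, out_ref_fp32: np.ndarray) -> None:
        # Looser tolerances for bfloat16, tight ones otherwise.
        if out_mx.dtype == mx.bfloat16:
            atol, rtol = 1e-1, 1e-3
        else:
            atol, rtol = 1e-5, 1e-6
        np.testing.assert_allclose(
            np.array(out_mx.astype(mx.float32)), out_ref_fp32, atol=atol, rtol=rtol
        )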