mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-05 07:34:42 +08:00
add bfloat conv for Winograd (#1306)
* add bfloat conv for Winograd
* accumulate in fp32
* accumulate in fp32
* accumulate in bf16
This commit is contained in:
@@ -275,7 +275,6 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
dilation=(1, 1),
|
||||
groups=1,
|
||||
dtype="float32",
|
||||
atol=1e-5,
|
||||
):
|
||||
with self.subTest(
|
||||
dtype=dtype,
|
||||
@@ -289,19 +288,22 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
):
|
||||
np_dtype = getattr(np, dtype)
|
||||
np.random.seed(0)
|
||||
iH, iW = idim
|
||||
kH, kW = kdim
|
||||
scale = 1.0 / math.sqrt(kH * kW * C)
|
||||
in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype)
|
||||
wt_np = np.random.normal(0.0, 1.0, (O, kH, kW, int(C / groups))).astype(
|
||||
np_dtype
|
||||
)
|
||||
in_np = np.random.normal(0.0, scale, (N, iH, iW, C))
|
||||
wt_np = np.random.normal(0.0, 1.0, (O, kH, kW, int(C / groups)))
|
||||
|
||||
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
|
||||
mx_dtype = getattr(mx, dtype)
|
||||
torch_dtype = getattr(torch, dtype)
|
||||
in_mx, wt_mx = map(
|
||||
lambda x: mx.array(x).astype(mx_dtype), (in_np, wt_np)
|
||||
)
|
||||
in_pt, wt_pt = map(
|
||||
lambda x: torch.from_numpy(x.transpose(0, 3, 1, 2)).to("cpu"),
|
||||
lambda x: torch.from_numpy(x.transpose(0, 3, 1, 2))
|
||||
.to("cpu")
|
||||
.to(torch_dtype),
|
||||
(in_np, wt_np),
|
||||
)
|
||||
|
||||
@@ -312,7 +314,7 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
)
|
||||
).astype(mx.float32)
|
||||
out_pt = torch.conv2d(
|
||||
in_pt,
|
||||
wt_pt,
|
||||
@@ -321,12 +323,20 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
)
|
||||
out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)
|
||||
out_pt = (
|
||||
torch.permute(out_pt, (0, 2, 3, 1))
|
||||
.to(torch.float32)
|
||||
.numpy(force=True)
|
||||
)
|
||||
|
||||
self.assertEqual(out_pt.shape, out_mx.shape)
|
||||
if dtype == "bfloat16":
|
||||
atol, rtol = 1e-1, 1e-3
|
||||
else:
|
||||
atol, rtol = 1e-5, 1e-6
|
||||
self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))
|
||||
|
||||
for dtype in ("float32",):
|
||||
for dtype in ("float32", "bfloat16"):
|
||||
for N, C, O in (
|
||||
(1, 1, 1),
|
||||
(1, 6, 1),
|
||||
|
Reference in New Issue
Block a user