Conv3d (#993)
* added conv3d; implemented explicit_gemm_conv_ND_cpu and bounds checks for slow_conv_3D
* incorporated reviewer comments
* fixed test
* reduced tensor shapes in test for conv3d
* applied reviewer suggestions

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
committed by GitHub
parent a9f80d60f6
commit ff4223904d
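This commit adds a 3D convolution op, mx.conv3d, backed on CPU by an explicit-GEMM path (explicit_gemm_conv_ND_cpu), with bounds checks added to the slow_conv_3D fallback. As the tests below show, MLX uses a channels-last layout: inputs are (N, D, H, W, C_in) and weights are (C_out, kD, kH, kW, C_in). A minimal usage sketch, with illustrative shapes not taken from this commit:

    import mlx.core as mx

    x = mx.random.normal((2, 8, 8, 8, 4))    # batch 2, 8^3 volume, 4 channels (NDHWC)
    w = mx.random.normal((16, 3, 3, 3, 4))   # 16 filters, 3^3 kernel, 4 input channels
    y = mx.conv3d(x, w, stride=(1, 1, 1), padding=(1, 1, 1))
    print(y.shape)  # (2, 8, 8, 8, 16): padding 1 with a 3^3 kernel preserves the volume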
@@ -399,7 +399,7 @@ class TestConv(mlx_tests.MLXTestCase):
                     [in_mx, wt_mx],
                     [ct_mx],
                 )
-                pt_grad_in = F.grad.conv1d_input(
+                pt_grad_in = F.grad.conv2d_input(
                     in_pt.shape,
                     wt_pt,
                     ct_pt,
@@ -408,7 +408,7 @@ class TestConv(mlx_tests.MLXTestCase):
                     dilation=dilation,
                     groups=groups,
                 )
-                pt_grad_wt = F.grad.conv1d_weight(
+                pt_grad_wt = F.grad.conv2d_weight(
                     in_pt,
                     wt_pt.shape,
                     ct_pt,
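These two hunks fix a copy-paste slip in the existing 2D gradient test, which had been calling PyTorch's 1D gradient helpers. torch.nn.grad (reached as F.grad in this file) provides the reference gradients per conv dimension; a minimal sketch with hypothetical shapes:

    import torch

    x = torch.randn(1, 4, 8, 8)    # NCHW input
    w = torch.randn(8, 4, 3, 3)    # (C_out, C_in, kH, kW) weights
    gy = torch.randn(1, 8, 6, 6)   # cotangent: 8x8 input, 3x3 kernel -> 6x6 output
    # Gradients of conv2d w.r.t. its input and its weight, given cotangent gy.
    gx = torch.nn.grad.conv2d_input(x.shape, w, gy, stride=1, padding=0)
    gw = torch.nn.grad.conv2d_weight(x, w.shape, gy, stride=1, padding=0)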
@@ -444,6 +444,203 @@ class TestConv(mlx_tests.MLXTestCase):
                         N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
                     )
 
+    @unittest.skipIf(not has_torch, "requires Torch")
+    def test_torch_conv_3D(self):
+        def run_conv3D(
+            N,
+            C,
+            O,
+            idim,
+            kdim,
+            stride,
+            padding,
+            dilation=(1, 1, 1),
+            groups=1,
+            dtype="float32",
+            atol=1e-5,
+        ):
+            with self.subTest(
+                dtype=dtype,
+                N=N,
+                C=C,
+                O=O,
+                idim=idim,
+                kdim=kdim,
+                stride=stride,
+                padding=padding,
+                dilation=dilation,
+                groups=groups,
+            ):
+                np_dtype = getattr(np, dtype)
+                np.random.seed(0)
+                iD, iH, iW = idim
+                kD, kH, kW = kdim
+                scale = 1.0 / math.sqrt(kD * kH * kW * C)
+                in_np = np.random.normal(0.0, scale, (N, iD, iH, iW, C)).astype(
+                    np_dtype
+                )
+                wt_np = np.random.normal(0.0, 1.0, (O, kD, kH, kW, C)).astype(np_dtype)
+
+                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
+                in_pt, wt_pt = map(
+                    lambda x: torch.from_numpy(x.transpose(0, 4, 1, 2, 3)).to("cpu"),
+                    (in_np, wt_np),
+                )
+
+                out_mx = mx.conv3d(
+                    in_mx,
+                    wt_mx,
+                    stride=stride,
+                    padding=padding,
+                    dilation=dilation,
+                    groups=groups,
+                )
+                out_pt = torch.conv3d(
+                    in_pt,
+                    wt_pt,
+                    stride=stride,
+                    padding=padding,
+                    dilation=dilation,
+                    groups=groups,
+                )
+                out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1)).numpy(force=True)
+
+                self.assertEqual(out_pt.shape, out_mx.shape)
+                self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))
+
+        for dtype in ("float32",):
+            for N, C, O in (
+                (1, 1, 1),
+                (1, 6, 1),
+                (1, 1, 6),
+                (4, 16, 32),
+            ):
+                for idim, kdim, stride, padding in (
+                    ((1, 1, 1), (1, 1, 1), (1, 1, 1), (0, 0, 0)),
+                    ((3, 3, 3), (3, 1, 1), (1, 1, 1), (0, 0, 0)),
+                    ((31, 31, 31), (5, 5, 5), (5, 5, 5), (2, 2, 2)),
+                ):
+                    run_conv3D(N, C, O, idim, kdim, stride, padding, dtype=dtype)
+
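Aside: the gradient test that follows precomputes the output extent per axis as o = 1 + (i + 2p - d*(k - 1) - 1) // s, the standard strided/dilated convolution size formula, so it can draw a cotangent of the right shape. A quick sanity check against two of its parameter sets (the out_dim helper is illustrative, not part of the diff):

    def out_dim(i, k, s, p, d):
        # 1 + (i + 2p - d(k - 1) - 1) // s, as inlined in the test below
        return 1 + (i + 2 * p - d * (k - 1) - 1) // s

    # 15^3 input, 5^3 kernel, stride 5, padding 2, no dilation:
    assert out_dim(15, 5, 5, 2, 1) == 3  # cotangent is (N, 3, 3, 3, O)
    # the same case with dilation 3 along the depth axis:
    assert out_dim(15, 5, 5, 2, 3) == 2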
+    @unittest.skipIf(not has_torch, "requires Torch")
+    def test_torch_conv_3D_grad(self):
+        def run_conv3D_grad(
+            N,
+            C,
+            O,
+            idim,
+            kdim,
+            stride,
+            padding,
+            dilation=(1, 1, 1),
+            groups=1,
+            dtype="float32",
+            atol=1e-5,
+        ):
+            with self.subTest(
+                dtype=dtype,
+                N=N,
+                C=C,
+                O=O,
+                idim=idim,
+                kdim=kdim,
+                stride=stride,
+                padding=padding,
+                dilation=dilation,
+                groups=groups,
+            ):
+                np_dtype = getattr(np, dtype)
+                np.random.seed(0)
+                iD, iH, iW = idim
+                kD, kH, kW = kdim
+                scale = 1.0 / math.sqrt(kD * kH * kW * C)
+
+                oD = 1 + (
+                    (iD + 2 * padding[0] - dilation[0] * (kD - 1) - 1) // stride[0]
+                )
+                oH = 1 + (
+                    (iH + 2 * padding[1] - dilation[1] * (kH - 1) - 1) // stride[1]
+                )
+                oW = 1 + (
+                    (iW + 2 * padding[2] - dilation[2] * (kW - 1) - 1) // stride[2]
+                )
+
+                in_np = np.random.normal(0.0, scale, (N, iD, iH, iW, C)).astype(
+                    np_dtype
+                )
+                wt_np = np.random.normal(0.0, scale, (O, kD, kH, kW, C)).astype(
+                    np_dtype
+                )
+                ct_np = np.random.normal(0.0, scale, (N, oD, oH, oW, O)).astype(
+                    np_dtype
+                )
+
+                in_mx, wt_mx, ct_mx = map(mx.array, (in_np, wt_np, ct_np))
+                in_pt, wt_pt, ct_pt = map(
+                    lambda x: torch.from_numpy(x.transpose(0, 4, 1, 2, 3)).to("cpu"),
+                    (in_np, wt_np, ct_np),
+                )
+
+                def f(a, b):
+                    return mx.conv3d(
+                        a,
+                        b,
+                        stride=stride,
+                        padding=padding,
+                        dilation=dilation,
+                        groups=groups,
+                    )
+
+                _, outs_mx = mx.vjp(
+                    f,
+                    [in_mx, wt_mx],
+                    [ct_mx],
+                )
+                pt_grad_in = F.grad.conv3d_input(
+                    in_pt.shape,
+                    wt_pt,
+                    ct_pt,
+                    stride=stride,
+                    padding=padding,
+                    dilation=dilation,
+                    groups=groups,
+                )
+                pt_grad_wt = F.grad.conv3d_weight(
+                    in_pt,
+                    wt_pt.shape,
+                    ct_pt,
+                    stride=stride,
+                    padding=padding,
+                    dilation=dilation,
+                    groups=groups,
+                )
+                pt_grad_in = torch.permute(pt_grad_in, (0, 2, 3, 4, 1)).numpy()
+                pt_grad_wt = torch.permute(pt_grad_wt, (0, 2, 3, 4, 1)).numpy()
+
+                mx_grad_in, mx_grad_wt = outs_mx
+
+                self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)
+                self.assertEqual(in_mx.shape, mx_grad_in.shape)
+                self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))
+
+                self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)
+                self.assertEqual(wt_mx.shape, mx_grad_wt.shape)
+                self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))
+
+        for dtype in ("float32",):
+            for N, C, O in ((1, 1, 1), (1, 6, 1), (1, 1, 6), (4, 16, 32), (4, 8, 16)):
+                for idim, kdim, stride, padding, dilation in (
+                    ((1, 1, 1), (1, 1, 1), (1, 1, 1), (0, 0, 0), (1, 1, 1)),
+                    ((3, 3, 3), (3, 1, 1), (1, 1, 1), (0, 0, 0), (1, 1, 1)),
+                    ((15, 15, 15), (5, 5, 5), (5, 5, 5), (2, 2, 2), (1, 1, 1)),
+                    ((16, 16, 16), (3, 3, 3), (2, 2, 2), (1, 1, 1), (1, 1, 1)),
+                    ((15, 15, 15), (5, 5, 5), (5, 5, 5), (2, 2, 2), (3, 2, 2)),
+                    ((16, 16, 16), (3, 3, 3), (2, 2, 2), (1, 1, 1), (3, 2, 2)),
+                ):
+                    run_conv3D_grad(
+                        N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
+                    )
 
     def __conv_general_test(
         self,
         in_shape,
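For reference, the explicit-GEMM strategy named in the commit message (explicit_gemm_conv_ND_cpu) amounts to unfolding every kernel-sized patch into a row (im2col) and multiplying by the flattened weights in a single matrix product. A NumPy sketch of the idea in 3D, matching the channels-last layout of the tests above; this illustrates the technique only and is not the actual C++ implementation (dilation and groups omitted for brevity):

    import numpy as np

    def conv3d_explicit_gemm(x, w, stride=(1, 1, 1), padding=(0, 0, 0)):
        # x: (N, D, H, W, C) input, w: (O, kD, kH, kW, C) weights (channels-last).
        N, D, H, W, C = x.shape
        O, kD, kH, kW, _ = w.shape
        sD, sH, sW = stride
        pD, pH, pW = padding
        x = np.pad(x, ((0, 0), (pD, pD), (pH, pH), (pW, pW), (0, 0)))
        oD = (D + 2 * pD - kD) // sD + 1
        oH = (H + 2 * pH - kH) // sH + 1
        oW = (W + 2 * pW - kW) // sW + 1
        # im2col: copy each kD*kH*kW*C patch out as one row.
        cols = np.empty((N, oD, oH, oW, kD * kH * kW * C), dtype=x.dtype)
        for d in range(oD):
            for h in range(oH):
                for v in range(oW):
                    patch = x[
                        :, d * sD:d * sD + kD, h * sH:h * sH + kH, v * sW:v * sW + kW, :
                    ]
                    cols[:, d, h, v, :] = patch.reshape(N, -1)
        # One GEMM against the flattened weights produces every output voxel.
        out = cols.reshape(-1, kD * kH * kW * C) @ w.reshape(O, -1).T
        return out.reshape(N, oD, oH, oW, O)

The trade-off is memory for speed: cols is kD*kH*kW times larger than the input, but the whole convolution becomes one well-optimized matrix multiply, which is why the commit adds this path alongside the bounds-checked slow_conv_3D fallback.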