* added conv3d

added conv3d

implemented explicit_gemm_conv_ND_cpu and bounds checks for slow_conv_3D (see the sketch after the commit metadata below)

* incorporated reviewer comments

* fixed test

* reduced tensor shapes in test for conv3d

* Reviewer suggestion

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Reviewer suggestion

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Reviewer suggestion

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Reviewer suggestion

Max-Heinrich Laves authored on 2024-05-11 15:15:02 +02:00, committed by GitHub
parent a9f80d60f6, commit ff4223904d
10 changed files with 951 additions and 13 deletions
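
For context, mx.conv3d follows MLX's channels-last layout: inputs are (N, D, H, W, C) and weights (O, kD, kH, kW, C), exactly as exercised by the tests below. A minimal usage sketch (shapes chosen here for illustration, not taken from the commit):

    import mlx.core as mx

    # Batch of 2, 8x8x8 volumes, 4 input channels; 16 filters of size 3x3x3.
    x = mx.random.normal((2, 8, 8, 8, 4))
    w = mx.random.normal((16, 3, 3, 3, 4))

    y = mx.conv3d(x, w, stride=(1, 1, 1), padding=(1, 1, 1), dilation=(1, 1, 1))
    print(y.shape)  # (2, 8, 8, 8, 16): padding 1 preserves extent for a stride-1 3x3x3 kernel

The explicit-GEMM path named in the commit message reduces N-D convolution to a single matrix multiply over gathered receptive fields (im2col, or vol2col in 3-D). The actual explicit_gemm_conv_ND_cpu is a C++ kernel inside MLX and does not appear in this diff; the NumPy sketch below only illustrates the idea, assuming the layouts above and ignoring groups. The zero-padding step stands in for the bounds checks the commit adds to slow_conv_3D:

    import numpy as np

    def explicit_gemm_conv3d(x, w, stride=(1, 1, 1), padding=(0, 0, 0), dilation=(1, 1, 1)):
        # x: (N, iD, iH, iW, C), w: (O, kD, kH, kW, C), both channels-last.
        N, iD, iH, iW, C = x.shape
        O, kD, kH, kW, _ = w.shape
        sD, sH, sW = stride
        pD, pH, pW = padding
        dD, dH, dW = dilation
        # Standard convolution output arithmetic per spatial dim.
        oD = 1 + (iD + 2 * pD - dD * (kD - 1) - 1) // sD
        oH = 1 + (iH + 2 * pH - dH * (kH - 1) - 1) // sH
        oW = 1 + (iW + 2 * pW - dW * (kW - 1) - 1) // sW
        # Zero-pad the spatial dims so every window read stays in bounds.
        xp = np.pad(x, ((0, 0), (pD, pD), (pH, pH), (pW, pW), (0, 0)))
        # vol2col: one row per output location, one column per (kd, kh, kw, c).
        col = np.empty((N, oD, oH, oW, kD, kH, kW, C), dtype=x.dtype)
        for zd in range(kD):
            for zh in range(kH):
                for zw in range(kW):
                    col[..., zd, zh, zw, :] = xp[
                        :,
                        zd * dD : zd * dD + oD * sD : sD,
                        zh * dH : zh * dH + oH * sH : sH,
                        zw * dW : zw * dW + oW * sW : sW,
                        :,
                    ]
        # One GEMM: (N*oD*oH*oW, kD*kH*kW*C) @ (kD*kH*kW*C, O).
        out = col.reshape(N * oD * oH * oW, -1) @ w.reshape(O, -1).T
        return out.reshape(N, oD, oH, oW, O)

Up to floating-point tolerance this reference should agree with the mx.conv3d results that the tests below check against torch.conv3d.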

@@ -399,7 +399,7 @@ class TestConv(mlx_tests.MLXTestCase):
                     [in_mx, wt_mx],
                     [ct_mx],
                 )
-                pt_grad_in = F.grad.conv1d_input(
+                pt_grad_in = F.grad.conv2d_input(
                     in_pt.shape,
                     wt_pt,
                     ct_pt,
@@ -408,7 +408,7 @@ class TestConv(mlx_tests.MLXTestCase):
                     dilation=dilation,
                     groups=groups,
                 )
-                pt_grad_wt = F.grad.conv1d_weight(
+                pt_grad_wt = F.grad.conv2d_weight(
                     in_pt,
                     wt_pt.shape,
                     ct_pt,
@@ -444,6 +444,203 @@ class TestConv(mlx_tests.MLXTestCase):
                     N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
                 )
 
+    @unittest.skipIf(not has_torch, "requires Torch")
+    def test_torch_conv_3D(self):
+        def run_conv3D(
+            N,
+            C,
+            O,
+            idim,
+            kdim,
+            stride,
+            padding,
+            dilation=(1, 1, 1),
+            groups=1,
+            dtype="float32",
+            atol=1e-5,
+        ):
+            with self.subTest(
+                dtype=dtype,
+                N=N,
+                C=C,
+                O=O,
+                idim=idim,
+                kdim=kdim,
+                stride=stride,
+                padding=padding,
+                dilation=dilation,
+                groups=groups,
+            ):
+                np_dtype = getattr(np, dtype)
+                np.random.seed(0)
+
+                iD, iH, iW = idim
+                kD, kH, kW = kdim
+                scale = 1.0 / math.sqrt(kD * kH * kW * C)
+                in_np = np.random.normal(0.0, scale, (N, iD, iH, iW, C)).astype(
+                    np_dtype
+                )
+                wt_np = np.random.normal(0.0, 1.0, (O, kD, kH, kW, C)).astype(np_dtype)
+
+                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
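+                # MLX is channels-last (NDHWC); PyTorch expects channels-first
+                # (NCDHW), hence the transposes to and from torch order.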
+                in_pt, wt_pt = map(
+                    lambda x: torch.from_numpy(x.transpose(0, 4, 1, 2, 3)).to("cpu"),
+                    (in_np, wt_np),
+                )
+
+                out_mx = mx.conv3d(
+                    in_mx,
+                    wt_mx,
+                    stride=stride,
+                    padding=padding,
+                    dilation=dilation,
+                    groups=groups,
+                )
+                out_pt = torch.conv3d(
+                    in_pt,
+                    wt_pt,
+                    stride=stride,
+                    padding=padding,
+                    dilation=dilation,
+                    groups=groups,
+                )
+                out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1)).numpy(force=True)
+
+                self.assertEqual(out_pt.shape, out_mx.shape)
+                self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))
+
+        for dtype in ("float32",):
+            for N, C, O in (
+                (1, 1, 1),
+                (1, 6, 1),
+                (1, 1, 6),
+                (4, 16, 32),
+            ):
+                for idim, kdim, stride, padding in (
+                    ((1, 1, 1), (1, 1, 1), (1, 1, 1), (0, 0, 0)),
+                    ((3, 3, 3), (3, 1, 1), (1, 1, 1), (0, 0, 0)),
+                    ((31, 31, 31), (5, 5, 5), (5, 5, 5), (2, 2, 2)),
+                ):
+                    run_conv3D(N, C, O, idim, kdim, stride, padding, dtype=dtype)
+
+    @unittest.skipIf(not has_torch, "requires Torch")
+    def test_torch_conv_3D_grad(self):
+        def run_conv3D_grad(
+            N,
+            C,
+            O,
+            idim,
+            kdim,
+            stride,
+            padding,
+            dilation=(1, 1, 1),
+            groups=1,
+            dtype="float32",
+            atol=1e-5,
+        ):
+            with self.subTest(
+                dtype=dtype,
+                N=N,
+                C=C,
+                O=O,
+                idim=idim,
+                kdim=kdim,
+                stride=stride,
+                padding=padding,
+                dilation=dilation,
+                groups=groups,
+            ):
+                np_dtype = getattr(np, dtype)
+                np.random.seed(0)
+
+                iD, iH, iW = idim
+                kD, kH, kW = kdim
+                scale = 1.0 / math.sqrt(kD * kH * kW * C)
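+                # Expected output extent per spatial dim, standard conv arithmetic:
+                #   o = 1 + (i + 2 * p - d * (k - 1) - 1) // s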
+                oD = 1 + (
+                    (iD + 2 * padding[0] - dilation[0] * (kD - 1) - 1) // stride[0]
+                )
+                oH = 1 + (
+                    (iH + 2 * padding[1] - dilation[1] * (kH - 1) - 1) // stride[1]
+                )
+                oW = 1 + (
+                    (iW + 2 * padding[2] - dilation[2] * (kW - 1) - 1) // stride[2]
+                )
+
+                in_np = np.random.normal(0.0, scale, (N, iD, iH, iW, C)).astype(
+                    np_dtype
+                )
+                wt_np = np.random.normal(0.0, scale, (O, kD, kH, kW, C)).astype(
+                    np_dtype
+                )
+                ct_np = np.random.normal(0.0, scale, (N, oD, oH, oW, O)).astype(
+                    np_dtype
+                )
+
+                in_mx, wt_mx, ct_mx = map(mx.array, (in_np, wt_np, ct_np))
+                in_pt, wt_pt, ct_pt = map(
+                    lambda x: torch.from_numpy(x.transpose(0, 4, 1, 2, 3)).to("cpu"),
+                    (in_np, wt_np, ct_np),
+                )
+
+                def f(a, b):
+                    return mx.conv3d(
+                        a,
+                        b,
+                        stride=stride,
+                        padding=padding,
+                        dilation=dilation,
+                        groups=groups,
+                    )
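+
+                # mx.vjp pulls the cotangent ct_mx back through f, yielding
+                # gradients with respect to [in_mx, wt_mx].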
+                _, outs_mx = mx.vjp(
+                    f,
+                    [in_mx, wt_mx],
+                    [ct_mx],
+                )
+
+                pt_grad_in = F.grad.conv3d_input(
+                    in_pt.shape,
+                    wt_pt,
+                    ct_pt,
+                    stride=stride,
+                    padding=padding,
+                    dilation=dilation,
+                    groups=groups,
+                )
+                pt_grad_wt = F.grad.conv3d_weight(
+                    in_pt,
+                    wt_pt.shape,
+                    ct_pt,
+                    stride=stride,
+                    padding=padding,
+                    dilation=dilation,
+                    groups=groups,
+                )
+                pt_grad_in = torch.permute(pt_grad_in, (0, 2, 3, 4, 1)).numpy()
+                pt_grad_wt = torch.permute(pt_grad_wt, (0, 2, 3, 4, 1)).numpy()
+
+                mx_grad_in, mx_grad_wt = outs_mx
+
+                self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)
+                self.assertEqual(in_mx.shape, mx_grad_in.shape)
+                self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))
+
+                self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)
+                self.assertEqual(wt_mx.shape, mx_grad_wt.shape)
+                self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))
+
+        for dtype in ("float32",):
+            for N, C, O in ((1, 1, 1), (1, 6, 1), (1, 1, 6), (4, 16, 32), (4, 8, 16)):
+                for idim, kdim, stride, padding, dilation in (
+                    ((1, 1, 1), (1, 1, 1), (1, 1, 1), (0, 0, 0), (1, 1, 1)),
+                    ((3, 3, 3), (3, 1, 1), (1, 1, 1), (0, 0, 0), (1, 1, 1)),
+                    ((15, 15, 15), (5, 5, 5), (5, 5, 5), (2, 2, 2), (1, 1, 1)),
+                    ((16, 16, 16), (3, 3, 3), (2, 2, 2), (1, 1, 1), (1, 1, 1)),
+                    ((15, 15, 15), (5, 5, 5), (5, 5, 5), (2, 2, 2), (3, 2, 2)),
+                    ((16, 16, 16), (3, 3, 3), (2, 2, 2), (1, 1, 1), (3, 2, 2)),
+                ):
+                    run_conv3D_grad(
+                        N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
+                    )
+
     def __conv_general_test(
         self,
         in_shape,