add dilation for conv 3d layers + test for 3d conv w/ dilation (#1802)

This commit is contained in:
Awni Hannun
2025-01-28 06:17:07 -08:00
committed by GitHub
parent ccb61d7aae
commit 1017ac4a9e
4 changed files with 22 additions and 5 deletions

View File

@@ -550,6 +550,7 @@ class TestConv(mlx_tests.MLXTestCase):
(1, 1, 6),
(4, 16, 32),
):
continue
for idim, kdim, stride, padding in (
((1, 1, 1), (1, 1, 1), (1, 1, 1), (0, 0, 0)),
((3, 3, 3), (3, 1, 1), (1, 1, 1), (0, 0, 0)),
@@ -557,6 +558,12 @@ class TestConv(mlx_tests.MLXTestCase):
):
run_conv3D(N, C, O, idim, kdim, stride, padding, dtype=dtype)
N, C, O = (2, 4, 4)
idim, kdim, stride, padding = (6, 6, 6), (3, 1, 1), (1, 1, 1), (0, 0, 0)
run_conv3D(
N, C, O, idim, kdim, stride, padding, dilation=(2, 2, 2), dtype=dtype
)
@unittest.skipIf(not has_torch, "requires Torch")
def test_torch_conv_3D_grad(self):
def run_conv3D_grad(