mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-04 15:04:40 +08:00
add dilation for conv 3d layers + test for 3d conv w/ dilation (#1802)
This commit is contained in:
@@ -550,6 +550,7 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
(1, 1, 6),
|
||||
(4, 16, 32),
|
||||
):
|
||||
continue
|
||||
for idim, kdim, stride, padding in (
|
||||
((1, 1, 1), (1, 1, 1), (1, 1, 1), (0, 0, 0)),
|
||||
((3, 3, 3), (3, 1, 1), (1, 1, 1), (0, 0, 0)),
|
||||
@@ -557,6 +558,12 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
):
|
||||
run_conv3D(N, C, O, idim, kdim, stride, padding, dtype=dtype)
|
||||
|
||||
N, C, O = (2, 4, 4)
|
||||
idim, kdim, stride, padding = (6, 6, 6), (3, 1, 1), (1, 1, 1), (0, 0, 0)
|
||||
run_conv3D(
|
||||
N, C, O, idim, kdim, stride, padding, dilation=(2, 2, 2), dtype=dtype
|
||||
)
|
||||
|
||||
@unittest.skipIf(not has_torch, "requires Torch")
|
||||
def test_torch_conv_3D_grad(self):
|
||||
def run_conv3D_grad(
|
||||
|
Reference in New Issue
Block a user