Add dilation support to nn.Conv3d and nn.ConvTranspose3d.

Both layers gain a `dilation` argument (default 1) that is stored on the
module, shown in `_extra_repr`, and forwarded to `mx.conv3d` /
`mx.conv_transpose3d`. The conv shape error message in mlx/ops.cpp loses a
doubled space, and the Torch comparison test gains one dilated conv3D case.

Review notes:
  * The test hunk no longer inserts a stray `continue` at the top of the
    `for N, C, O in (...)` loop; that statement made every existing
    conv3D comparison case unreachable.
  * The new `dilation` docstring entries state their default and follow the
    signature order (after `padding`) in both layers.
  * Line structure was restored from a whitespace-collapsed copy of the
    patch; indentation and blank context lines are reconstructed from the
    hunk line counts. The `index` hashes of the three Python files are the
    original ones and were not recomputed for the edited content.

diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 311ad830f..ce7216fea 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -3540,7 +3540,7 @@ Shape conv_out_shape(
 
     if (out_shape[i] <= 0) {
       std::ostringstream msg;
-      msg << "[conv] Spatial dimensions of input after padding "
+      msg << "[conv] Spatial dimensions of input after padding"
           << " cannot be smaller than weight spatial dimensions."
           << " Got error at axis " << i << " for input with shape " << in_shape
           << ", padding low " << pads_lo << ", padding high " << pads_hi
diff --git a/python/mlx/nn/layers/convolution.py b/python/mlx/nn/layers/convolution.py
index f29341044..f825307c0 100644
--- a/python/mlx/nn/layers/convolution.py
+++ b/python/mlx/nn/layers/convolution.py
@@ -179,6 +179,8 @@ class Conv3d(Module):
         kernel_size (int or tuple): The size of the convolution filters.
         stride (int or tuple, optional): The size of the stride when
             applying the filter. Default: ``1``.
         padding (int or tuple, optional): How many positions to 0-pad
             the input with. Default: ``0``.
+        dilation (int or tuple, optional): The dilation of the convolution.
+            Default: ``1``.
         bias (bool, optional): If ``True`` add a learnable bias to the
@@ -192,6 +194,7 @@ class Conv3d(Module):
         kernel_size: Union[int, tuple],
         stride: Union[int, tuple] = 1,
         padding: Union[int, tuple] = 0,
+        dilation: Union[int, tuple] = 1,
         bias: bool = True,
     ):
         super().__init__()
@@ -213,16 +216,18 @@ class Conv3d(Module):
 
         self.padding = padding
         self.stride = stride
+        self.dilation = dilation
 
     def _extra_repr(self):
         return (
             f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
             f"kernel_size={self.weight.shape[1:3]}, stride={self.stride}, "
-            f"padding={self.padding}, bias={'bias' in self}"
+            f"padding={self.padding}, dilation={self.dilation}, "
+            f"bias={'bias' in self}"
         )
 
     def __call__(self, x):
-        y = mx.conv3d(x, self.weight, self.stride, self.padding)
+        y = mx.conv3d(x, self.weight, self.stride, self.padding, self.dilation)
         if "bias" in self:
             y = y + self.bias
         return y
diff --git a/python/mlx/nn/layers/convolution_transpose.py b/python/mlx/nn/layers/convolution_transpose.py
index ec55049e5..edacab061 100644
--- a/python/mlx/nn/layers/convolution_transpose.py
+++ b/python/mlx/nn/layers/convolution_transpose.py
@@ -159,6 +159,8 @@ class ConvTranspose3d(Module):
             applying the filter. Default: ``1``.
         padding (int or tuple, optional): How many positions to 0-pad
             the input with. Default: ``0``.
+        dilation (int or tuple, optional): The dilation of the convolution.
+            Default: ``1``.
         bias (bool, optional): If ``True`` add a learnable bias to the
             output. Default: ``True``
     """
@@ -170,6 +172,7 @@ class ConvTranspose3d(Module):
         kernel_size: Union[int, tuple],
         stride: Union[int, tuple] = 1,
         padding: Union[int, tuple] = 0,
+        dilation: Union[int, tuple] = 1,
         bias: bool = True,
     ):
         super().__init__()
@@ -191,16 +194,20 @@ class ConvTranspose3d(Module):
 
         self.padding = padding
         self.stride = stride
+        self.dilation = dilation
 
     def _extra_repr(self):
         return (
             f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
             f"kernel_size={self.weight.shape[1:3]}, stride={self.stride}, "
-            f"padding={self.padding}, bias={'bias' in self}"
+            f"padding={self.padding}, dilation={self.dilation}, "
+            f"bias={'bias' in self}"
         )
 
     def __call__(self, x):
-        y = mx.conv_transpose3d(x, self.weight, self.stride, self.padding)
+        y = mx.conv_transpose3d(
+            x, self.weight, self.stride, self.padding, self.dilation
+        )
         if "bias" in self:
             y = y + self.bias
         return y
diff --git a/python/tests/test_conv.py b/python/tests/test_conv.py
index e446e1df8..79324829b 100644
--- a/python/tests/test_conv.py
+++ b/python/tests/test_conv.py
@@ -557,6 +557,12 @@ class TestConv(mlx_tests.MLXTestCase):
                 ):
                     run_conv3D(N, C, O, idim, kdim, stride, padding, dtype=dtype)
 
+            N, C, O = (2, 4, 4)
+            idim, kdim, stride, padding = (6, 6, 6), (3, 1, 1), (1, 1, 1), (0, 0, 0)
+            run_conv3D(
+                N, C, O, idim, kdim, stride, padding, dilation=(2, 2, 2), dtype=dtype
+            )
+
     @unittest.skipIf(not has_torch, "requires Torch")
     def test_torch_conv_3D_grad(self):
         def run_conv3D_grad(