From f2b5ba49af3a75d2bb2ed04c761a345095b7b426 Mon Sep 17 00:00:00 2001
From: paramthakkar123
Date: Tue, 22 Apr 2025 18:55:06 +0530
Subject: [PATCH] Added output_padding to nn.layers and corresponding tests

---
 mlx/ops.h                                     |   3 +
 python/mlx/nn/layers/convolution_transpose.py |  44 +++-
 python/tests/test_conv_transpose.py           | 209 ++++++++++++++++++
 3 files changed, 249 insertions(+), 7 deletions(-)

diff --git a/mlx/ops.h b/mlx/ops.h
index e79ea235d..12e896af6 100644
--- a/mlx/ops.h
+++ b/mlx/ops.h
@@ -1291,6 +1291,7 @@ array conv_transpose1d(
     int stride = 1,
     int padding = 0,
     int dilation = 1,
+    int output_padding = 0,
     int groups = 1,
     StreamOrDevice s = {});
 
@@ -1301,6 +1302,7 @@ array conv_transpose2d(
     const std::pair<int, int>& stride = {1, 1},
     const std::pair<int, int>& padding = {0, 0},
     const std::pair<int, int>& dilation = {1, 1},
+    const std::pair<int, int>& output_padding = {0, 0},
     int groups = 1,
     StreamOrDevice s = {});
 
@@ -1311,6 +1313,7 @@ array conv_transpose3d(
     const std::tuple<int, int, int>& stride = {1, 1, 1},
     const std::tuple<int, int, int>& padding = {0, 0, 0},
     const std::tuple<int, int, int>& dilation = {1, 1, 1},
+    const std::tuple<int, int, int>& output_padding = {0, 0, 0},
     int groups = 1,
     StreamOrDevice s = {});
 
diff --git a/python/mlx/nn/layers/convolution_transpose.py b/python/mlx/nn/layers/convolution_transpose.py
index edacab061..ff321ac98 100644
--- a/python/mlx/nn/layers/convolution_transpose.py
+++ b/python/mlx/nn/layers/convolution_transpose.py
@@ -25,6 +25,8 @@ class ConvTranspose1d(Module):
         padding (int, optional): How many positions to 0-pad the input with.
             Default: ``0``.
         dilation (int, optional): The dilation of the convolution.
+        output_padding (int, optional): Additional size added to one side of the output shape.
+            Default: ``0``.
         bias (bool, optional): If ``True`` add a learnable bias to the output.
             Default: ``True``
     """
@@ -37,6 +39,7 @@ class ConvTranspose1d(Module):
         stride: int = 1,
         padding: int = 0,
         dilation: int = 1,
+        output_padding: int = 0,
         bias: bool = True,
     ):
         super().__init__()
@@ -53,18 +56,25 @@ class ConvTranspose1d(Module):
         self.padding = padding
         self.dilation = dilation
         self.stride = stride
+        self.output_padding = output_padding
 
     def _extra_repr(self):
         return (
             f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
             f"kernel_size={self.weight.shape[1]}, stride={self.stride}, "
             f"padding={self.padding}, dilation={self.dilation}, "
+            f"output_padding={self.output_padding}, "
             f"bias={'bias' in self}"
         )
 
     def __call__(self, x):
         y = mx.conv_transpose1d(
-            x, self.weight, self.stride, self.padding, self.dilation
+            x,
+            self.weight,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.output_padding,
         )
         if "bias" in self:
             y = y + self.bias
@@ -90,6 +100,8 @@ class ConvTranspose2d(Module):
         padding (int or tuple, optional): How many positions to 0-pad the input with.
             Default: ``0``.
         dilation (int or tuple, optional): The dilation of the convolution.
+        output_padding (int or tuple, optional): Additional size added to one side of the output shape.
+            Default: ``0``.
         bias (bool, optional): If ``True`` add a learnable bias to the output.
             Default: ``True``
     """
@@ -102,13 +114,14 @@ class ConvTranspose2d(Module):
         stride: Union[int, tuple] = 1,
         padding: Union[int, tuple] = 0,
         dilation: Union[int, tuple] = 1,
+        output_padding: Union[int, tuple] = 0,
         bias: bool = True,
     ):
         super().__init__()
 
-        kernel_size, stride, padding = map(
+        kernel_size, stride, padding, output_padding = map(
             lambda x: (x, x) if isinstance(x, int) else x,
-            (kernel_size, stride, padding),
+            (kernel_size, stride, padding, output_padding),
         )
         scale = math.sqrt(1 / (in_channels * kernel_size[0] * kernel_size[1]))
         self.weight = mx.random.uniform(
@@ -122,18 +135,25 @@ class ConvTranspose2d(Module):
         self.padding = padding
         self.stride = stride
         self.dilation = dilation
+        self.output_padding = output_padding
 
     def _extra_repr(self):
         return (
             f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
             f"kernel_size={self.weight.shape[1:2]}, stride={self.stride}, "
             f"padding={self.padding}, dilation={self.dilation}, "
+            f"output_padding={self.output_padding}, "
             f"bias={'bias' in self}"
         )
 
     def __call__(self, x):
         y = mx.conv_transpose2d(
-            x, self.weight, self.stride, self.padding, self.dilation
+            x,
+            self.weight,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.output_padding,
         )
         if "bias" in self:
             y = y + self.bias
@@ -160,6 +180,8 @@ class ConvTranspose3d(Module):
         padding (int or tuple, optional): How many positions to 0-pad the input with.
             Default: ``0``.
         dilation (int or tuple, optional): The dilation of the convolution.
+        output_padding (int or tuple, optional): Additional size added to one side of the output shape.
+            Default: ``0``.
         bias (bool, optional): If ``True`` add a learnable bias to the output.
             Default: ``True``
     """
@@ -172,13 +194,14 @@ class ConvTranspose3d(Module):
         stride: Union[int, tuple] = 1,
         padding: Union[int, tuple] = 0,
         dilation: Union[int, tuple] = 1,
+        output_padding: Union[int, tuple] = 0,
         bias: bool = True,
     ):
         super().__init__()
-        kernel_size, stride, padding = map(
+        kernel_size, stride, padding, output_padding = map(
             lambda x: (x, x, x) if isinstance(x, int) else x,
-            (kernel_size, stride, padding),
+            (kernel_size, stride, padding, output_padding),
         )
         scale = math.sqrt(
             1 / (in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2])
         )
@@ -194,18 +217,25 @@ class ConvTranspose3d(Module):
         self.padding = padding
         self.stride = stride
         self.dilation = dilation
+        self.output_padding = output_padding
 
     def _extra_repr(self):
         return (
             f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
             f"kernel_size={self.weight.shape[1:3]}, stride={self.stride}, "
             f"padding={self.padding}, dilation={self.dilation}, "
+            f"output_padding={self.output_padding}, "
             f"bias={'bias' in self}"
         )
 
     def __call__(self, x):
         y = mx.conv_transpose3d(
-            x, self.weight, self.stride, self.padding, self.dilation
+            x,
+            self.weight,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.output_padding,
         )
         if "bias" in self:
             y = y + self.bias
diff --git a/python/tests/test_conv_transpose.py b/python/tests/test_conv_transpose.py
index 1ac20cbb1..2085e09d7 100644
--- a/python/tests/test_conv_transpose.py
+++ b/python/tests/test_conv_transpose.py
@@ -596,6 +596,215 @@ class TestConvTranspose(mlx_tests.MLXTestCase):
                 N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
             )
 
+    @unittest.skipIf(not has_torch, "requires Torch")
+    def test_torch_conv_transpose_1d_output_padding(self):
+        def run_conv_transpose_1d_output_padding(
+            N, C, O, iH, kH, stride, padding, output_padding, dtype="float32", atol=1e-5
+        ):
+            with self.subTest(
+                dtype=dtype,
+                N=N,
+                C=C,
+                O=O,
+                iH=iH,
+                kH=kH,
+                stride=stride,
+                padding=padding,
+                output_padding=output_padding,
+            ):
+                np_dtype = getattr(np, dtype)
+                np.random.seed(0)
+                in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)
+                wt_np = np.random.normal(0, 1.0 / C, (O, kH, C)).astype(np_dtype)
+
+                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
+                in_pt = torch.from_numpy(in_np.transpose(0, 2, 1))
+                wt_pt = torch.from_numpy(wt_np.transpose(2, 0, 1))
+
+                out_mx = mx.conv_transpose1d(
+                    in_mx,
+                    wt_mx,
+                    stride=stride,
+                    padding=padding,
+                    output_padding=output_padding,
+                )
+
+                out_pt = torch.conv_transpose1d(
+                    in_pt,
+                    wt_pt,
+                    stride=stride,
+                    padding=padding,
+                    output_padding=output_padding,
+                )
+                out_pt = torch.transpose(out_pt, 2, 1)
+
+                self.assertEqual(out_pt.shape, out_mx.shape)
+                self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=atol))
+
+        for dtype in ("float32",):
+            for N, C, O in ((1, 1, 1), (1, 6, 1), (4, 32, 64)):
+                for iH, kH, stride, padding, output_padding in (
+                    (3, 2, 2, 0, 1),
+                    (5, 3, 2, 1, 0),
+                    (7, 4, 3, 1, 2),
+                ):
+                    run_conv_transpose_1d_output_padding(
+                        N, C, O, iH, kH, stride, padding, output_padding, dtype=dtype
+                    )
+
+    @unittest.skipIf(not has_torch, "requires Torch")
+    def test_torch_conv_transpose_2d_output_padding(self):
+        def run_conv_transpose_2d_output_padding(
+            N,
+            C,
+            O,
+            idim,
+            kdim,
+            stride,
+            padding,
+            output_padding,
+            dtype="float32",
+            atol=1e-5,
+        ):
+            with self.subTest(
+                dtype=dtype,
+                N=N,
+                C=C,
+                O=O,
+                idim=idim,
+                kdim=kdim,
+                stride=stride,
+                padding=padding,
+                output_padding=output_padding,
+            ):
+                np_dtype = getattr(np, dtype)
+                np.random.seed(0)
+                iH, iW = idim
+                kH, kW = kdim
+                in_np = np.random.normal(0, 1.0 / C, (N, iH, iW, C)).astype(np_dtype)
+                wt_np = np.random.normal(0, 1.0 / C, (O, kH, kW, C)).astype(np_dtype)
+
+                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
+                in_pt = torch.from_numpy(in_np.transpose(0, 3, 1, 2))
+                wt_pt = torch.from_numpy(wt_np.transpose(3, 0, 1, 2))
+
+                out_mx = mx.conv_transpose2d(
+                    in_mx,
+                    wt_mx,
+                    stride=stride,
+                    padding=padding,
+                    output_padding=output_padding,
+                )
+
+                out_pt = torch.conv_transpose2d(
+                    in_pt,
+                    wt_pt,
+                    stride=stride,
+                    padding=padding,
+                    output_padding=output_padding,
+                )
+                out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)
+
+                self.assertEqual(out_pt.shape, out_mx.shape)
+                self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))
+
+        for dtype in ("float32",):
+            for N, C, O in ((1, 1, 1), (1, 6, 1), (4, 32, 64)):
+                for idim, kdim, stride, padding, output_padding in (
+                    ((3, 3), (2, 2), (2, 2), (0, 0), (1, 1)),
+                    ((5, 5), (3, 3), (2, 2), (1, 1), (0, 0)),
+                    ((7, 7), (4, 4), (3, 3), (1, 1), (2, 2)),
+                ):
+                    run_conv_transpose_2d_output_padding(
+                        N,
+                        C,
+                        O,
+                        idim,
+                        kdim,
+                        stride,
+                        padding,
+                        output_padding,
+                        dtype=dtype,
+                    )
+
+    @unittest.skipIf(not has_torch, "requires Torch")
+    def test_torch_conv_transpose_3d_output_padding(self):
+        def run_conv_transpose_3d_output_padding(
+            N,
+            C,
+            O,
+            idim,
+            kdim,
+            stride,
+            padding,
+            output_padding,
+            dtype="float32",
+            atol=1e-5,
+        ):
+            with self.subTest(
+                dtype=dtype,
+                N=N,
+                C=C,
+                O=O,
+                idim=idim,
+                kdim=kdim,
+                stride=stride,
+                padding=padding,
+                output_padding=output_padding,
+            ):
+                np_dtype = getattr(np, dtype)
+                np.random.seed(0)
+                iD, iH, iW = idim
+                kD, kH, kW = kdim
+                in_np = np.random.normal(0, 1.0 / C, (N, iD, iH, iW, C)).astype(
+                    np_dtype
+                )
+                wt_np = np.random.normal(0, 1.0 / C, (O, kD, kH, kW, C)).astype(
+                    np_dtype
+                )
+
+                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
+                in_pt = torch.from_numpy(in_np.transpose(0, 4, 1, 2, 3))
+                wt_pt = torch.from_numpy(wt_np.transpose(4, 0, 1, 2, 3))
+
+                out_mx = mx.conv_transpose3d(
+                    in_mx,
+                    wt_mx,
+                    stride=stride,
+                    padding=padding,
+                    output_padding=output_padding,
+                )
+                out_pt = torch.conv_transpose3d(
+                    in_pt,
+                    wt_pt,
+                    stride=stride,
+                    padding=padding,
+                    output_padding=output_padding,
+                )
+                out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1)).numpy(force=True)
+
+                self.assertEqual(out_pt.shape, out_mx.shape)
+                self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))
+
+        for dtype in ("float32",):
+            for N, C, O in ((1, 1, 1), (1, 6, 1), (4, 32, 64)):
+                for idim, kdim, stride, padding, output_padding in (
+                    ((3, 3, 3), (2, 2, 2), (2, 2, 2), (0, 0, 0), (1, 1, 1)),
+                    ((5, 5, 5), (3, 3, 3), (2, 2, 2), (1, 1, 1), (0, 0, 0)),
+                    ((7, 7, 7), (4, 4, 4), (3, 3, 3), (1, 1, 1), (2, 2, 2)),
+                ):
+                    run_conv_transpose_3d_output_padding(
+                        N,
+                        C,
+                        O,
+                        idim,
+                        kdim,
+                        stride,
+                        padding,
+                        output_padding,
+                        dtype=dtype,
+                    )
+
 
 if __name__ == "__main__":
     unittest.main()
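
Note for reviewers (not part of the patch): output_padding exists because a strided transposed convolution's output size is ambiguous. For stride > 1, several forward-convolution input lengths collapse to the same output length, and output_padding selects which one to reconstruct by adding extra size to one side of the output. The expected 1-D output length follows the standard formula documented for PyTorch's ConvTranspose1d, which the tests above compare against:

    L_out = (L_in - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1

The sketch below is illustrative only; the helper name expected_out_len is made up. It evaluates this formula for the shape cases exercised by the new 1-D test:

    def expected_out_len(l_in, kernel, stride=1, padding=0, dilation=1, output_padding=0):
        # Standard output-size formula for a 1-D transposed convolution.
        return (l_in - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1

    # (iH, kH, stride, padding, output_padding) cases from
    # test_torch_conv_transpose_1d_output_padding:
    for iH, kH, stride, padding, output_padding in (
        (3, 2, 2, 0, 1),
        (5, 3, 2, 1, 0),
        (7, 4, 3, 1, 2),
    ):
        print(iH, "->", expected_out_len(iH, kH, stride, padding, output_padding=output_padding))
    # Prints: 3 -> 7, 5 -> 9, 7 -> 22. With this patch applied,
    # mx.conv_transpose1d(..., output_padding=...) produces outputs of these lengths.

PyTorch additionally requires output_padding to be smaller than either stride or dilation; the test cases above respect that constraint.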