diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 54ac62fef..c2aa4786f 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -3769,6 +3769,7 @@ array conv_transpose_general(
     std::vector<int> stride,
     std::vector<int> padding,
     std::vector<int> dilation,
+    std::vector<int> output_padding,
     int groups,
     StreamOrDevice s) {
   std::vector<int> padding_lo(padding.size());
@@ -3782,7 +3783,8 @@ array conv_transpose_general(
     int in_size = 1 + (conv_output_shape - 1);
     int out_size = 1 + stride[i] * (input.shape(1 + i) - 1);
 
-    padding_hi[i] = in_size - out_size + padding[i];
+    padding_hi[i] = in_size - out_size + padding[i] +
+        output_padding[i]; // Adjust with output_padding
   }
 
   return conv_general(
@@ -3805,10 +3807,11 @@ array conv_transpose1d(
     int stride /* = 1 */,
     int padding /* = 0 */,
     int dilation /* = 1 */,
+    int output_padding /* = 0 */,
     int groups /* = 1 */,
     StreamOrDevice s /* = {} */) {
   return conv_transpose_general(
-      in_, wt_, {stride}, {padding}, {dilation}, groups, s);
+      in_, wt_, {stride}, {padding}, {dilation}, {output_padding}, groups, s);
 }
 
 /** 2D transposed convolution with a filter */
@@ -3818,6 +3821,7 @@ array conv_transpose2d(
     const std::pair<int, int>& stride /* = {1, 1} */,
     const std::pair<int, int>& padding /* = {0, 0} */,
     const std::pair<int, int>& dilation /* = {1, 1} */,
+    const std::pair<int, int>& output_padding /* = {0, 0} */,
     int groups /* = 1 */,
     StreamOrDevice s /* = {} */) {
   return conv_transpose_general(
@@ -3826,6 +3830,7 @@ array conv_transpose2d(
       {stride.first, stride.second},
       {padding.first, padding.second},
       {dilation.first, dilation.second},
+      {output_padding.first, output_padding.second},
       groups,
       s);
 }
@@ -3837,6 +3842,7 @@ array conv_transpose3d(
     const std::tuple<int, int, int>& stride /* = {1, 1, 1} */,
     const std::tuple<int, int, int>& padding /* = {0, 0, 0} */,
     const std::tuple<int, int, int>& dilation /* = {1, 1, 1} */,
+    const std::tuple<int, int, int>& output_padding /* = {0, 0, 0} */,
     int groups /* = 1 */,
     StreamOrDevice s /* = {} */) {
   return conv_transpose_general(
@@ -3845,6 +3851,9 @@ array conv_transpose3d(
       {std::get<0>(stride), std::get<1>(stride), std::get<2>(stride)},
       {std::get<0>(padding), std::get<1>(padding), std::get<2>(padding)},
       {std::get<0>(dilation), std::get<1>(dilation), std::get<2>(dilation)},
+      {std::get<0>(output_padding),
+       std::get<1>(output_padding),
+       std::get<2>(output_padding)},
       groups,
       s);
 }
diff --git a/mlx/ops.h b/mlx/ops.h
index e79ea235d..12e896af6 100644
--- a/mlx/ops.h
+++ b/mlx/ops.h
@@ -1291,6 +1291,7 @@ array conv_transpose1d(
     int stride = 1,
     int padding = 0,
     int dilation = 1,
+    int output_padding = 0,
     int groups = 1,
     StreamOrDevice s = {});
 
@@ -1301,6 +1302,7 @@ array conv_transpose2d(
     const std::pair<int, int>& stride = {1, 1},
     const std::pair<int, int>& padding = {0, 0},
     const std::pair<int, int>& dilation = {1, 1},
+    const std::pair<int, int>& output_padding = {0, 0},
     int groups = 1,
     StreamOrDevice s = {});
 
@@ -1311,6 +1313,7 @@ array conv_transpose3d(
     const std::tuple<int, int, int>& stride = {1, 1, 1},
     const std::tuple<int, int, int>& padding = {0, 0, 0},
     const std::tuple<int, int, int>& dilation = {1, 1, 1},
+    const std::tuple<int, int, int>& output_padding = {0, 0, 0},
     int groups = 1,
     StreamOrDevice s = {});
 
diff --git a/python/mlx/nn/layers/convolution_transpose.py b/python/mlx/nn/layers/convolution_transpose.py
index edacab061..a11c4cb40 100644
--- a/python/mlx/nn/layers/convolution_transpose.py
+++ b/python/mlx/nn/layers/convolution_transpose.py
@@ -25,6 +25,8 @@ class ConvTranspose1d(Module):
         padding (int, optional): How many positions to 0-pad
             the input with. Default: ``0``.
         dilation (int, optional): The dilation of the convolution.
+        output_padding (int, optional): Additional size added to one side of the
+            output shape. Default: ``0``.
         bias (bool, optional): If ``True`` add a learnable bias to the
             output. Default: ``True``
     """
@@ -37,6 +39,7 @@ class ConvTranspose1d(Module):
         stride: int = 1,
         padding: int = 0,
         dilation: int = 1,
+        output_padding: int = 0,
         bias: bool = True,
     ):
         super().__init__()
@@ -53,18 +56,25 @@ class ConvTranspose1d(Module):
         self.padding = padding
         self.dilation = dilation
         self.stride = stride
+        self.output_padding = output_padding
 
     def _extra_repr(self):
         return (
             f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
             f"kernel_size={self.weight.shape[1]}, stride={self.stride}, "
             f"padding={self.padding}, dilation={self.dilation}, "
+            f"output_padding={self.output_padding}, "
             f"bias={'bias' in self}"
         )
 
     def __call__(self, x):
         y = mx.conv_transpose1d(
-            x, self.weight, self.stride, self.padding, self.dilation
+            x,
+            self.weight,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.output_padding,
         )
         if "bias" in self:
             y = y + self.bias
@@ -90,6 +100,8 @@ class ConvTranspose2d(Module):
         padding (int or tuple, optional): How many positions to 0-pad
             the input with. Default: ``0``.
         dilation (int or tuple, optional): The dilation of the convolution.
+        output_padding (int or tuple, optional): Additional size added to one
+            side of the output shape. Default: ``0``.
         bias (bool, optional): If ``True`` add a learnable bias to the
             output. Default: ``True``
     """
@@ -102,13 +114,14 @@ class ConvTranspose2d(Module):
         stride: Union[int, tuple] = 1,
         padding: Union[int, tuple] = 0,
         dilation: Union[int, tuple] = 1,
+        output_padding: Union[int, tuple] = 0,
         bias: bool = True,
     ):
         super().__init__()
 
-        kernel_size, stride, padding = map(
+        kernel_size, stride, padding, output_padding = map(
             lambda x: (x, x) if isinstance(x, int) else x,
-            (kernel_size, stride, padding),
+            (kernel_size, stride, padding, output_padding),
         )
         scale = math.sqrt(1 / (in_channels * kernel_size[0] * kernel_size[1]))
         self.weight = mx.random.uniform(
@@ -122,18 +135,25 @@ class ConvTranspose2d(Module):
         self.padding = padding
         self.stride = stride
         self.dilation = dilation
+        self.output_padding = output_padding
 
     def _extra_repr(self):
         return (
             f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
             f"kernel_size={self.weight.shape[1:2]}, stride={self.stride}, "
             f"padding={self.padding}, dilation={self.dilation}, "
+            f"output_padding={self.output_padding}, "
             f"bias={'bias' in self}"
         )
 
     def __call__(self, x):
         y = mx.conv_transpose2d(
-            x, self.weight, self.stride, self.padding, self.dilation
+            x,
+            self.weight,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.output_padding,
         )
         if "bias" in self:
             y = y + self.bias
@@ -160,6 +180,8 @@ class ConvTranspose3d(Module):
         padding (int or tuple, optional): How many positions to 0-pad
             the input with. Default: ``0``.
         dilation (int or tuple, optional): The dilation of the convolution.
+        output_padding (int or tuple, optional): Additional size added to one
+            side of the output shape. Default: ``0``.
         bias (bool, optional): If ``True`` add a learnable bias to the
             output. Default: ``True``
     """
@@ -172,13 +194,14 @@ class ConvTranspose3d(Module):
         stride: Union[int, tuple] = 1,
         padding: Union[int, tuple] = 0,
         dilation: Union[int, tuple] = 1,
+        output_padding: Union[int, tuple] = 0,
         bias: bool = True,
     ):
         super().__init__()
 
-        kernel_size, stride, padding = map(
+        kernel_size, stride, padding, output_padding = map(
             lambda x: (x, x, x) if isinstance(x, int) else x,
-            (kernel_size, stride, padding),
+            (kernel_size, stride, padding, output_padding),
         )
         scale = math.sqrt(
             1 / (in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2])
         )
@@ -194,18 +217,25 @@ class ConvTranspose3d(Module):
         self.padding = padding
         self.stride = stride
         self.dilation = dilation
+        self.output_padding = output_padding
 
     def _extra_repr(self):
         return (
             f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
             f"kernel_size={self.weight.shape[1:3]}, stride={self.stride}, "
             f"padding={self.padding}, dilation={self.dilation}, "
+            f"output_padding={self.output_padding}, "
             f"bias={'bias' in self}"
         )
 
     def __call__(self, x):
         y = mx.conv_transpose3d(
-            x, self.weight, self.stride, self.padding, self.dilation
+            x,
+            self.weight,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.output_padding,
         )
         if "bias" in self:
             y = y + self.bias
diff --git a/python/src/ops.cpp b/python/src/ops.cpp
index 5969c5052..60b6188ed 100644
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -3609,11 +3609,12 @@ void init_ops(nb::module_& m) {
       "stride"_a = 1,
       "padding"_a = 0,
       "dilation"_a = 1,
+      "output_padding"_a = 0,
       "groups"_a = 1,
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def conv_transpose1d(input: array, weight: array, /, stride: int = 1, padding: int = 0, dilation: int = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def conv_transpose1d(input: array, weight: array, /, stride: int = 1, padding: int = 0, dilation: int = 1, output_padding: int = 0, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         1D transposed convolution over an input with several channels
 
@@ -3623,6 +3624,7 @@ void init_ops(nb::module_& m) {
            stride (int, optional): Kernel stride. Default: ``1``.
            padding (int, optional): Input padding. Default: ``0``.
            dilation (int, optional): Kernel dilation. Default: ``1``.
+           output_padding (int, optional): Output padding. Default: ``0``.
            groups (int, optional): Input feature groups. Default: ``1``.
 
        Returns:
@@ -3635,11 +3637,13 @@ void init_ops(nb::module_& m) {
          const std::variant<int, std::pair<int, int>>& stride,
          const std::variant<int, std::pair<int, int>>& padding,
          const std::variant<int, std::pair<int, int>>& dilation,
+         const std::variant<int, std::pair<int, int>>& output_padding,
          int groups,
          mx::StreamOrDevice s) {
        std::pair<int, int> stride_pair{1, 1};
        std::pair<int, int> padding_pair{0, 0};
        std::pair<int, int> dilation_pair{1, 1};
+       std::pair<int, int> output_padding_pair{0, 0};
 
        if (auto pv = std::get_if<int>(&stride); pv) {
          stride_pair = std::pair<int, int>{*pv, *pv};
@@ -3659,19 +3663,33 @@ void init_ops(nb::module_& m) {
          dilation_pair = std::get<std::pair<int, int>>(dilation);
        }
 
+       if (auto pv = std::get_if<int>(&output_padding); pv) {
+         output_padding_pair = std::pair<int, int>{*pv, *pv};
+       } else {
+         output_padding_pair = std::get<std::pair<int, int>>(output_padding);
+       }
+
        return mx::conv_transpose2d(
-           input, weight, stride_pair, padding_pair, dilation_pair, groups, s);
+           input,
+           weight,
+           stride_pair,
+           padding_pair,
+           dilation_pair,
+           output_padding_pair,
+           groups,
+           s);
      },
      nb::arg(),
      nb::arg(),
      "stride"_a = 1,
      "padding"_a = 0,
      "dilation"_a = 1,
+     "output_padding"_a = 0,
      "groups"_a = 1,
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
-         "def conv_transpose2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
+         "def conv_transpose2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, output_padding: Union[int, Tuple[int, int]] = 0, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        2D transposed convolution over an input with several channels
 
@@ -3689,6 +3707,9 @@ void init_ops(nb::module_& m) {
            dilation (int or tuple(int), optional): :obj:`tuple` of size 2 with
              kernel dilation. All spatial dimensions get the same dilation
              if only one number is specified. Default: ``1``
+           output_padding (int or tuple(int), optional): :obj:`tuple` of size 2 with
+             output padding. All spatial dimensions get the same output
+             padding if only one number is specified. Default: ``0``.
            groups (int, optional): input feature groups. Default: ``1``.
 
        Returns:
@@ -3701,11 +3722,13 @@ void init_ops(nb::module_& m) {
          const std::variant<int, std::tuple<int, int, int>>& stride,
          const std::variant<int, std::tuple<int, int, int>>& padding,
          const std::variant<int, std::tuple<int, int, int>>& dilation,
+         const std::variant<int, std::tuple<int, int, int>>& output_padding,
          int groups,
          mx::StreamOrDevice s) {
        std::tuple<int, int, int> stride_tuple{1, 1, 1};
        std::tuple<int, int, int> padding_tuple{0, 0, 0};
        std::tuple<int, int, int> dilation_tuple{1, 1, 1};
+       std::tuple<int, int, int> output_padding_tuple{0, 0, 0};
 
        if (auto pv = std::get_if<int>(&stride); pv) {
          stride_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};
@@ -3725,12 +3748,20 @@ void init_ops(nb::module_& m) {
          dilation_tuple = std::get<std::tuple<int, int, int>>(dilation);
        }
 
+       if (auto pv = std::get_if<int>(&output_padding); pv) {
+         output_padding_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};
+       } else {
+         output_padding_tuple =
+             std::get<std::tuple<int, int, int>>(output_padding);
+       }
+
        return mx::conv_transpose3d(
            input,
            weight,
            stride_tuple,
            padding_tuple,
            dilation_tuple,
+           output_padding_tuple,
            groups,
            s);
      },
@@ -3739,11 +3770,12 @@ void init_ops(nb::module_& m) {
      nb::arg(),
      nb::arg(),
      "stride"_a = 1,
      "padding"_a = 0,
      "dilation"_a = 1,
+     "output_padding"_a = 0,
      "groups"_a = 1,
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
-         "def conv_transpose3d(input: array, weight: array, /, stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, dilation: Union[int, Tuple[int, int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
+         "def conv_transpose3d(input: array, weight: array, /, stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, dilation: Union[int, Tuple[int, int, int]] = 1, output_padding: Union[int, Tuple[int, int, int]] = 0, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        3D transposed convolution over an input with several channels
 
@@ -3761,6 +3793,9 @@ void init_ops(nb::module_& m) {
            dilation (int or tuple(int), optional): :obj:`tuple` of size 3 with
              kernel dilation. All spatial dimensions get the same dilation
              if only one number is specified. Default: ``1``
+           output_padding (int or tuple(int), optional): :obj:`tuple` of size 3 with
+             output padding. All spatial dimensions get the same output
+             padding if only one number is specified. Default: ``0``.
            groups (int, optional): input feature groups. Default: ``1``.
 
        Returns:
diff --git a/python/tests/test_conv_transpose.py b/python/tests/test_conv_transpose.py
index 1ac20cbb1..2085e09d7 100644
--- a/python/tests/test_conv_transpose.py
+++ b/python/tests/test_conv_transpose.py
@@ -596,6 +596,215 @@ class TestConvTranspose(mlx_tests.MLXTestCase):
                         N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
                     )
 
+    @unittest.skipIf(not has_torch, "requires Torch")
+    def test_torch_conv_transpose_1d_output_padding(self):
+        def run_conv_transpose_1d_output_padding(
+            N, C, O, iH, kH, stride, padding, output_padding, dtype="float32", atol=1e-5
+        ):
+            with self.subTest(
+                dtype=dtype,
+                N=N,
+                C=C,
+                O=O,
+                iH=iH,
+                kH=kH,
+                stride=stride,
+                padding=padding,
+                output_padding=output_padding,
+            ):
+                np_dtype = getattr(np, dtype)
+                np.random.seed(0)
+                in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)
+                wt_np = np.random.normal(0, 1.0 / C, (O, kH, C)).astype(np_dtype)
+
+                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
+                in_pt = torch.from_numpy(in_np.transpose(0, 2, 1))
+                wt_pt = torch.from_numpy(wt_np.transpose(2, 0, 1))
+
+                out_mx = mx.conv_transpose1d(
+                    in_mx,
+                    wt_mx,
+                    stride=stride,
+                    padding=padding,
+                    output_padding=output_padding,
+                )
+
+                out_pt = torch.conv_transpose1d(
+                    in_pt,
+                    wt_pt,
+                    stride=stride,
+                    padding=padding,
+                    output_padding=output_padding,
+                )
+                out_pt = torch.transpose(out_pt, 2, 1)
+
+                self.assertEqual(out_pt.shape, out_mx.shape)
+                self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=atol))
+
+        for dtype in ("float32",):
+            for N, C, O in ((1, 1, 1), (1, 6, 1), (4, 32, 64)):
+                for iH, kH, stride, padding, output_padding in (
+                    (3, 2, 2, 0, 1),
+                    (5, 3, 2, 1, 0),
+                    (7, 4, 3, 1, 2),
+                ):
+                    run_conv_transpose_1d_output_padding(
+                        N, C, O, iH, kH, stride, padding, output_padding, dtype=dtype
+                    )
+
+    @unittest.skipIf(not has_torch, "requires Torch")
+    def test_torch_conv_transpose_2d_output_padding(self):
+        def run_conv_transpose_2d_output_padding(
+            N,
+            C,
+            O,
+            idim,
+            kdim,
+            stride,
+            padding,
+            output_padding,
+            dtype="float32",
+            atol=1e-5,
+        ):
+            with self.subTest(
+                dtype=dtype,
+                N=N,
+                C=C,
+                O=O,
+                idim=idim,
+                kdim=kdim,
+                stride=stride,
+                padding=padding,
+                output_padding=output_padding,
+            ):
+                np_dtype = getattr(np, dtype)
+                np.random.seed(0)
+                iH, iW = idim
+                kH, kW = kdim
+                in_np = np.random.normal(0, 1.0 / C, (N, iH, iW, C)).astype(np_dtype)
+                wt_np = np.random.normal(0, 1.0 / C, (O, kH, kW, C)).astype(np_dtype)
+
+                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
+                in_pt = torch.from_numpy(in_np.transpose(0, 3, 1, 2))
+                wt_pt = torch.from_numpy(wt_np.transpose(3, 0, 1, 2))
+
+                out_mx = mx.conv_transpose2d(
+                    in_mx,
+                    wt_mx,
+                    stride=stride,
+                    padding=padding,
+                    output_padding=output_padding,
+                )
+
+                out_pt = torch.conv_transpose2d(
+                    in_pt,
+                    wt_pt,
+                    stride=stride,
+                    padding=padding,
+                    output_padding=output_padding,
+                )
+                out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)
+
+                self.assertEqual(out_pt.shape, out_mx.shape)
+                self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))
+
+        for dtype in ("float32",):
+            for N, C, O in ((1, 1, 1), (1, 6, 1), (4, 32, 64)):
+                for idim, kdim, stride, padding, output_padding in (
+                    ((3, 3), (2, 2), (2, 2), (0, 0), (1, 1)),
+                    ((5, 5), (3, 3), (2, 2), (1, 1), (0, 0)),
+                    ((7, 7), (4, 4), (3, 3), (1, 1), (2, 2)),
+                ):
+                    run_conv_transpose_2d_output_padding(
+                        N,
+                        C,
+                        O,
+                        idim,
+                        kdim,
+                        stride,
+                        padding,
+                        output_padding,
+                        dtype=dtype,
+                    )
+
+    @unittest.skipIf(not has_torch, "requires Torch")
+    def test_torch_conv_transpose_3d_output_padding(self):
+        def run_conv_transpose_3d_output_padding(
+            N,
+            C,
+            O,
+            idim,
+            kdim,
+            stride,
+            padding,
+            output_padding,
+            dtype="float32",
+            atol=1e-5,
+        ):
+            with self.subTest(
+                dtype=dtype,
+                N=N,
+                C=C,
+                O=O,
+                idim=idim,
+                kdim=kdim,
+                stride=stride,
+                padding=padding,
+                output_padding=output_padding,
+            ):
+                np_dtype = getattr(np, dtype)
+                np.random.seed(0)
+                iD, iH, iW = idim
+                kD, kH, kW = kdim
+                in_np = np.random.normal(0, 1.0 / C, (N, iD, iH, iW, C)).astype(
+                    np_dtype
+                )
+                wt_np = np.random.normal(0, 1.0 / C, (O, kD, kH, kW, C)).astype(
+                    np_dtype
+                )
+
+                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
+                in_pt = torch.from_numpy(in_np.transpose(0, 4, 1, 2, 3))
+                wt_pt = torch.from_numpy(wt_np.transpose(4, 0, 1, 2, 3))
+
+                out_mx = mx.conv_transpose3d(
+                    in_mx,
+                    wt_mx,
+                    stride=stride,
+                    padding=padding,
+                    output_padding=output_padding,
+                )
+                out_pt = torch.conv_transpose3d(
+                    in_pt,
+                    wt_pt,
+                    stride=stride,
+                    padding=padding,
+                    output_padding=output_padding,
+                )
+                out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1)).numpy(force=True)
+
+                self.assertEqual(out_pt.shape, out_mx.shape)
+                self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))
+
+        for dtype in ("float32",):
+            for N, C, O in ((1, 1, 1), (1, 6, 1), (4, 32, 64)):
+                for idim, kdim, stride, padding, output_padding in (
+                    ((3, 3, 3), (2, 2, 2), (2, 2, 2), (0, 0, 0), (1, 1, 1)),
+                    ((5, 5, 5), (3, 3, 3), (2, 2, 2), (1, 1, 1), (0, 0, 0)),
+                    ((7, 7, 7), (4, 4, 4), (3, 3, 3), (1, 1, 1), (2, 2, 2)),
+                ):
+                    run_conv_transpose_3d_output_padding(
+                        N,
+                        C,
+                        O,
+                        idim,
+                        kdim,
+                        stride,
+                        padding,
+                        output_padding,
+                        dtype=dtype,
+                    )
 
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp
index de0f3352c..c4f319d46 100644
--- a/tests/ops_tests.cpp
+++ b/tests/ops_tests.cpp
@@ -3911,4 +3911,70 @@ TEST_CASE("test bitwise shift operations") {
   CHECK_EQ(right_shift_bool_result.dtype(), uint8);
   CHECK(array_equal(right_shift_bool_result, full({4}, 0, uint8)).item<bool>());
-}
\ No newline at end of file
+}
+
+TEST_CASE("test conv_transpose1d with output_padding") {
+  auto in = array({1.0, 2.0, 3.0}, {1, 1, 3});
+  auto wt = array({1.0, 1.0, 1.0}, {1, 1, 3});
+  int stride = 2;
+  int padding = 0;
+  int dilation = 1;
+  int output_padding = 1;
+  int groups = 1;
+
+  auto out = conv_transpose1d(
+      in, wt, stride, padding, dilation, output_padding, groups);
+  auto expected = array({6.0, 0.0}, {1, 2, 1});
+  CHECK(array_equal(out, expected).item<bool>());
+}
+
+TEST_CASE("test conv_transpose2d with output_padding") {
+  auto in = array({1.0, 2.0, 3.0, 4.0}, {1, 1, 2, 2});
+  auto wt = array({1.0, 1.0, 1.0, 1.0}, {2, 1, 1, 2});
+  std::pair<int, int> stride{2, 2};
+  std::pair<int, int> padding{0, 0};
+  std::pair<int, int> output_padding{1, 1};
+  std::pair<int, int> dilation{1, 1};
+  int groups = 1;
+
+  auto out = conv_transpose2d(
+      in, wt, stride, padding, dilation, output_padding, groups);
+  auto expected = array(
+      {3.0,
+       3.0,
+       0.0,
+       0.0,
+       7.0,
+       7.0,
+       0.0,
+       0.0,
+       0.0,
+       0.0,
+       0.0,
+       0.0,
+       0.0,
+       0.0,
+       0.0,
+       0.0},
+      {1, 2, 4, 2});
+  CHECK(array_equal(out, expected).item<bool>());
+}
+
+TEST_CASE("test conv_transpose3d with output_padding") {
+  auto in = array({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}, {1, 1, 2, 2, 2});
+  auto wt = array({1.0, 1.0}, {1, 1, 1, 1, 2});
+  std::tuple<int, int, int> stride{2, 2, 2};
+  std::tuple<int, int, int> padding{0, 0, 0};
+  std::tuple<int, int, int> output_padding{1, 1, 1};
+  std::tuple<int, int, int> dilation{1, 1, 1};
+  int groups = 1;
+
+  auto out = conv_transpose3d(
+      in, wt, stride, padding, dilation, output_padding, groups);
+  auto expected = array(
+      {3.0,  0.0, 7.0,  0.0, 0.0, 0.0, 0.0, 0.0, 11.0, 0.0, 15.0,
+       0.0,  0.0, 0.0,  0.0, 0.0, 0.0, 0.0, 0.0, 0.0,  0.0, 0.0,
+       0.0,  0.0, 0.0,  0.0, 0.0, 0.0, 0.0, 0.0, 0.0,  0.0},
+      {1, 2, 4, 4, 1});
+  CHECK(array_equal(out, expected).item<bool>());
+}
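
Reviewer note on the shape semantics: output_padding is added to padding_hi only, never padding_lo, so it grows each spatial axis on the high side without shifting existing values, and the per-axis output size follows the usual transposed-convolution formula: out = (in - 1) * stride - 2 * padding + dilation * (kernel - 1) + 1 + output_padding. A minimal sanity-check sketch against the Python API added in this patch (the expected_size helper is illustrative, not part of MLX; assumes this branch is built and importable):

import mlx.core as mx

def expected_size(in_size, kernel, stride=1, padding=0, dilation=1, output_padding=0):
    # Transposed-convolution output size along one spatial axis.
    return (in_size - 1) * stride - 2 * padding + dilation * (kernel - 1) + 1 + output_padding

x = mx.random.normal((1, 5, 4))  # (N, H, C_in) layout
w = mx.random.normal((8, 3, 4))  # (C_out, kH, C_in) layout

for out_pad in (0, 1):
    y = mx.conv_transpose1d(x, w, stride=2, padding=1, output_padding=out_pad)
    # output_padding=1 grows H from 9 to 10; the first 9 positions are unchanged.
    assert tuple(y.shape) == (1, expected_size(5, 3, stride=2, padding=1, output_padding=out_pad), 8)

One observation: unlike torch, the patch adds no range check, so an output_padding >= stride (or negative) flows straight into the pad computation; whether that should raise is worth a comment from the authors.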
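
The nn layers thread output_padding straight through to the ops, so the standard encoder/decoder use case works: with stride 2, a Conv2d maps both even and odd input sizes to the same output size, and output_padding lets the ConvTranspose2d pick which one to restore. A small usage sketch under the same assumptions as above:

import mlx.core as mx
import mlx.nn as nn

down = nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1)
up = nn.ConvTranspose2d(8, 3, kernel_size=3, stride=2, padding=1, output_padding=1)

x = mx.zeros((1, 8, 8, 3))  # NHWC, even spatial size
h = down(x)                 # (1, 4, 4, 8)
y = up(h)                   # (1, 8, 8, 3); without output_padding=1 this would be (1, 7, 7, 3)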