Added output_padding parameters in conv_transpose (#2092)

Param Thakkar
2025-04-23 21:56:33 +05:30
committed by GitHub
parent 3836445241
commit 600e87e03c
6 changed files with 366 additions and 14 deletions

View File

@@ -25,6 +25,8 @@ class ConvTranspose1d(Module):
padding (int, optional): How many positions to 0-pad the input with.
Default: ``0``.
dilation (int, optional): The dilation of the convolution.
        output_padding (int, optional): Additional size added to one side of the
            output shape. Default: ``0``.
bias (bool, optional): If ``True`` add a learnable bias to the output.
Default: ``True``
"""
@@ -37,6 +39,7 @@ class ConvTranspose1d(Module):
stride: int = 1,
padding: int = 0,
dilation: int = 1,
output_padding: int = 0,
bias: bool = True,
):
super().__init__()
@@ -53,18 +56,25 @@ class ConvTranspose1d(Module):
self.padding = padding
self.dilation = dilation
self.stride = stride
self.output_padding = output_padding

    def _extra_repr(self):
return (
f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
f"kernel_size={self.weight.shape[1]}, stride={self.stride}, "
f"padding={self.padding}, dilation={self.dilation}, "
f"output_padding={self.output_padding}, "
f"bias={'bias' in self}"
)

    def __call__(self, x):
y = mx.conv_transpose1d(
x, self.weight, self.stride, self.padding, self.dilation
x,
self.weight,
self.stride,
self.padding,
self.dilation,
self.output_padding,
)
if "bias" in self:
y = y + self.bias
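
A minimal usage sketch of the new parameter on the 1D module (hedged: assumes this commit's signatures; MLX arrays are channels-last, so the input is (N, L, C_in)):

import mlx.core as mx
import mlx.nn as nn

x = mx.random.normal((1, 5, 8))  # (N, L, C_in)

# With stride=2 the forward conv maps several input lengths to the same output
# length, so the transpose is ambiguous; output_padding=1 picks the longer one.
layer = nn.ConvTranspose1d(8, 4, kernel_size=3, stride=2, padding=1, output_padding=1)
y = layer(x)

# L_out = (L - 1)*stride - 2*padding + dilation*(kernel - 1) + output_padding + 1
#       = (5 - 1)*2 - 2*1 + 1*(3 - 1) + 1 + 1 = 10
print(y.shape)  # (1, 10, 4)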
@@ -90,6 +100,8 @@ class ConvTranspose2d(Module):
padding (int or tuple, optional): How many positions to 0-pad
the input with. Default: ``0``.
dilation (int or tuple, optional): The dilation of the convolution.
        output_padding (int or tuple, optional): Additional size added to one
            side of the output shape. Default: ``0``.
bias (bool, optional): If ``True`` add a learnable bias to the
output. Default: ``True``
"""
@@ -102,13 +114,14 @@ class ConvTranspose2d(Module):
stride: Union[int, tuple] = 1,
padding: Union[int, tuple] = 0,
dilation: Union[int, tuple] = 1,
output_padding: Union[int, tuple] = 0,
bias: bool = True,
):
super().__init__()
kernel_size, stride, padding = map(
kernel_size, stride, padding, output_padding = map(
lambda x: (x, x) if isinstance(x, int) else x,
(kernel_size, stride, padding),
(kernel_size, stride, padding, output_padding),
)
scale = math.sqrt(1 / (in_channels * kernel_size[0] * kernel_size[1]))
self.weight = mx.random.uniform(
@@ -122,18 +135,25 @@ class ConvTranspose2d(Module):
self.padding = padding
self.stride = stride
self.dilation = dilation
self.output_padding = output_padding

    def _extra_repr(self):
return (
f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
f"kernel_size={self.weight.shape[1:2]}, stride={self.stride}, "
f"padding={self.padding}, dilation={self.dilation}, "
f"output_padding={self.output_padding}, "
f"bias={'bias' in self}"
)

    def __call__(self, x):
y = mx.conv_transpose2d(
x, self.weight, self.stride, self.padding, self.dilation
x,
self.weight,
self.stride,
self.padding,
self.dilation,
self.output_padding,
)
if "bias" in self:
y = y + self.bias
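
The 2D module accepts either an int or a per-axis tuple; a hedged sketch of the tuple form:

import mlx.core as mx
import mlx.nn as nn

layer = nn.ConvTranspose2d(3, 8, kernel_size=4, stride=3, padding=1, output_padding=(2, 1))
y = layer(mx.zeros((1, 2, 2, 3)))  # (N, H, W, C_in)
print(y.shape)  # (1, 7, 6, 8): H gets +2 extra positions, W gets +1, per the tuple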
@@ -160,6 +180,8 @@ class ConvTranspose3d(Module):
padding (int or tuple, optional): How many positions to 0-pad
the input with. Default: ``0``.
dilation (int or tuple, optional): The dilation of the convolution.
        output_padding (int or tuple, optional): Additional size added to one
            side of the output shape. Default: ``0``.
bias (bool, optional): If ``True`` add a learnable bias to the
output. Default: ``True``
"""
@@ -172,13 +194,14 @@ class ConvTranspose3d(Module):
stride: Union[int, tuple] = 1,
padding: Union[int, tuple] = 0,
dilation: Union[int, tuple] = 1,
output_padding: Union[int, tuple] = 0,
bias: bool = True,
):
super().__init__()
kernel_size, stride, padding = map(
kernel_size, stride, padding, output_padding = map(
lambda x: (x, x, x) if isinstance(x, int) else x,
(kernel_size, stride, padding),
(kernel_size, stride, padding, output_padding),
)
scale = math.sqrt(
1 / (in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2])
@@ -194,18 +217,25 @@ class ConvTranspose3d(Module):
self.padding = padding
self.stride = stride
self.dilation = dilation
self.output_padding = output_padding

    def _extra_repr(self):
return (
f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
f"kernel_size={self.weight.shape[1:3]}, stride={self.stride}, "
f"padding={self.padding}, dilation={self.dilation}, "
f"output_padding={self.output_padding}, "
f"bias={'bias' in self}"
)

    def __call__(self, x):
y = mx.conv_transpose3d(
x, self.weight, self.stride, self.padding, self.dilation
x,
self.weight,
self.stride,
self.padding,
self.dilation,
self.output_padding,
)
if "bias" in self:
y = y + self.bias

View File

@@ -3609,11 +3609,12 @@ void init_ops(nb::module_& m) {
"stride"_a = 1,
"padding"_a = 0,
"dilation"_a = 1,
"output_padding"_a = 0,
"groups"_a = 1,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def conv_transpose1d(input: array, weight: array, /, stride: int = 1, padding: int = 0, dilation: int = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
"def conv_transpose1d(input: array, weight: array, /, stride: int = 1, padding: int = 0, dilation: int = 1, output_padding: int = 0, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
1D transposed convolution over an input with several channels
@@ -3623,6 +3624,7 @@ void init_ops(nb::module_& m) {
stride (int, optional): Kernel stride. Default: ``1``.
padding (int, optional): Input padding. Default: ``0``.
dilation (int, optional): Kernel dilation. Default: ``1``.
output_padding (int, optional): Output padding. Default: ``0``.
groups (int, optional): Input feature groups. Default: ``1``.
Returns:
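
A hedged example of calling the bound op directly; the weight layout (C_out, K, C_in) follows the tests added below:

import mlx.core as mx

x = mx.zeros((1, 5, 8))  # (N, L, C_in)
w = mx.zeros((4, 3, 8))  # (C_out, K, C_in)
y = mx.conv_transpose1d(x, w, stride=2, padding=1, dilation=1, output_padding=1)
print(y.shape)  # (1, 10, 4), matching the module example above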
@@ -3635,11 +3637,13 @@ void init_ops(nb::module_& m) {
const std::variant<int, std::pair<int, int>>& stride,
const std::variant<int, std::pair<int, int>>& padding,
const std::variant<int, std::pair<int, int>>& dilation,
const std::variant<int, std::pair<int, int>>& output_padding,
int groups,
mx::StreamOrDevice s) {
std::pair<int, int> stride_pair{1, 1};
std::pair<int, int> padding_pair{0, 0};
std::pair<int, int> dilation_pair{1, 1};
std::pair<int, int> output_padding_pair{0, 0};
if (auto pv = std::get_if<int>(&stride); pv) {
stride_pair = std::pair<int, int>{*pv, *pv};
@@ -3659,19 +3663,33 @@ void init_ops(nb::module_& m) {
dilation_pair = std::get<std::pair<int, int>>(dilation);
}
if (auto pv = std::get_if<int>(&output_padding); pv) {
output_padding_pair = std::pair<int, int>{*pv, *pv};
} else {
output_padding_pair = std::get<std::pair<int, int>>(output_padding);
}
return mx::conv_transpose2d(
input, weight, stride_pair, padding_pair, dilation_pair, groups, s);
input,
weight,
stride_pair,
padding_pair,
dilation_pair,
output_padding_pair,
groups,
s);
},
nb::arg(),
nb::arg(),
"stride"_a = 1,
"padding"_a = 0,
"dilation"_a = 1,
"output_padding"_a = 0,
"groups"_a = 1,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def conv_transpose2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
"def conv_transpose2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, output_padding: Union[int, Tuple[int, int]] = 0, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
2D transposed convolution over an input with several channels
@@ -3689,6 +3707,9 @@ void init_ops(nb::module_& m) {
dilation (int or tuple(int), optional): :obj:`tuple` of size 2 with
kernel dilation. All spatial dimensions get the same dilation
if only one number is specified. Default: ``1``
output_padding (int or tuple(int), optional): :obj:`tuple` of size 2 with
output padding. All spatial dimensions get the same output
padding if only one number is specified. Default: ``0``.
groups (int, optional): input feature groups. Default: ``1``.
Returns:
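
Why the parameter exists: with stride > 1 the forward convolution is many-to-one in spatial size, so its transpose needs a tie-breaker. A hedged round-trip sketch (mx.conv2d predates this commit; all shapes channels-last):

import mlx.core as mx

x = mx.zeros((1, 7, 7, 3))  # 5x5, 6x6 and 7x7 inputs all reach the same
w = mx.zeros((8, 4, 4, 3))  # 2x2 map under kernel=4, stride=3, padding=1
down = mx.conv2d(x, w, stride=3, padding=1)
print(down.shape)  # (1, 2, 2, 8)

wt = mx.zeros((3, 4, 4, 8))  # (C_out, kH, kW, C_in)
up = mx.conv_transpose2d(down, wt, stride=3, padding=1, output_padding=(2, 2))
print(up.shape)  # (1, 7, 7, 3): output_padding selects the 7x7 preimage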
@@ -3701,11 +3722,13 @@ void init_ops(nb::module_& m) {
const std::variant<int, std::tuple<int, int, int>>& stride,
const std::variant<int, std::tuple<int, int, int>>& padding,
const std::variant<int, std::tuple<int, int, int>>& dilation,
const std::variant<int, std::tuple<int, int, int>>& output_padding,
int groups,
mx::StreamOrDevice s) {
std::tuple<int, int, int> stride_tuple{1, 1, 1};
std::tuple<int, int, int> padding_tuple{0, 0, 0};
std::tuple<int, int, int> dilation_tuple{1, 1, 1};
std::tuple<int, int, int> output_padding_tuple{0, 0, 0};
if (auto pv = std::get_if<int>(&stride); pv) {
stride_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};
@@ -3725,12 +3748,20 @@ void init_ops(nb::module_& m) {
dilation_tuple = std::get<std::tuple<int, int, int>>(dilation);
}
if (auto pv = std::get_if<int>(&output_padding); pv) {
output_padding_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};
} else {
output_padding_tuple =
std::get<std::tuple<int, int, int>>(output_padding);
}
return mx::conv_transpose3d(
input,
weight,
stride_tuple,
padding_tuple,
dilation_tuple,
output_padding_tuple,
groups,
s);
},
@@ -3739,11 +3770,12 @@ void init_ops(nb::module_& m) {
"stride"_a = 1,
"padding"_a = 0,
"dilation"_a = 1,
"output_padding"_a = 0,
"groups"_a = 1,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def conv_transpose3d(input: array, weight: array, /, stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, dilation: Union[int, Tuple[int, int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
"def conv_transpose3d(input: array, weight: array, /, stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, dilation: Union[int, Tuple[int, int, int]] = 1, output_padding: Union[int, Tuple[int, int, int]] = 0, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
3D transposed convolution over an input with several channels
@@ -3761,6 +3793,9 @@ void init_ops(nb::module_& m) {
dilation (int or tuple(int), optional): :obj:`tuple` of size 3 with
kernel dilation. All spatial dimensions get the same dilation
if only one number is specified. Default: ``1``
output_padding (int or tuple(int), optional): :obj:`tuple` of size 3 with
output padding. All spatial dimensions get the same output
padding if only one number is specified. Default: ``0``.
groups (int, optional): input feature groups. Default: ``1``.
Returns:
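
The 3D op follows the same shape rule; this hedged sketch mirrors the first spatial configuration exercised by the new tests:

import mlx.core as mx

x = mx.zeros((1, 3, 3, 3, 2))  # (N, D, H, W, C_in)
w = mx.zeros((4, 2, 2, 2, 2))  # (C_out, kD, kH, kW, C_in)
y = mx.conv_transpose3d(x, w, stride=2, padding=0, output_padding=1)
print(y.shape)  # (1, 7, 7, 7, 4): (3-1)*2 + (2-1) + 1 + 1 = 7 per axis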

View File

@@ -596,6 +596,215 @@ class TestConvTranspose(mlx_tests.MLXTestCase):
N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
)

    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_transpose_1d_output_padding(self):
def run_conv_transpose_1d_output_padding(
N, C, O, iH, kH, stride, padding, output_padding, dtype="float32", atol=1e-5
):
with self.subTest(
dtype=dtype,
N=N,
C=C,
O=O,
iH=iH,
kH=kH,
stride=stride,
padding=padding,
output_padding=output_padding,
):
np_dtype = getattr(np, dtype)
np.random.seed(0)
in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)
wt_np = np.random.normal(0, 1.0 / C, (O, kH, C)).astype(np_dtype)
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
in_pt = torch.from_numpy(in_np.transpose(0, 2, 1))
wt_pt = torch.from_numpy(wt_np.transpose(2, 0, 1))
out_mx = mx.conv_transpose1d(
in_mx,
wt_mx,
stride=stride,
padding=padding,
output_padding=output_padding,
)
out_pt = torch.conv_transpose1d(
in_pt,
wt_pt,
stride=stride,
padding=padding,
output_padding=output_padding,
)
out_pt = torch.transpose(out_pt, 2, 1)
self.assertEqual(out_pt.shape, out_mx.shape)
self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=atol))

        for dtype in ("float32",):
for N, C, O in ((1, 1, 1), (1, 6, 1), (4, 32, 64)):
for iH, kH, stride, padding, output_padding in (
(3, 2, 2, 0, 1),
(5, 3, 2, 1, 0),
(7, 4, 3, 1, 2),
):
run_conv_transpose_1d_output_padding(
N, C, O, iH, kH, stride, padding, output_padding, dtype=dtype
)

    @unittest.skipIf(not has_torch, "requires Torch")
def test_torch_conv_transpose_2d_output_padding(self):
def run_conv_transpose_2d_output_padding(
N,
C,
O,
idim,
kdim,
stride,
padding,
output_padding,
dtype="float32",
atol=1e-5,
):
with self.subTest(
dtype=dtype,
N=N,
C=C,
O=O,
idim=idim,
kdim=kdim,
stride=stride,
padding=padding,
output_padding=output_padding,
):
np_dtype = getattr(np, dtype)
np.random.seed(0)
iH, iW = idim
kH, kW = kdim
in_np = np.random.normal(0, 1.0 / C, (N, iH, iW, C)).astype(np_dtype)
wt_np = np.random.normal(0, 1.0 / C, (O, kH, kW, C)).astype(np_dtype)
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
in_pt = torch.from_numpy(in_np.transpose(0, 3, 1, 2))
wt_pt = torch.from_numpy(wt_np.transpose(3, 0, 1, 2))
out_mx = mx.conv_transpose2d(
in_mx,
wt_mx,
stride=stride,
padding=padding,
output_padding=output_padding,
)
out_pt = torch.conv_transpose2d(
in_pt,
wt_pt,
stride=stride,
padding=padding,
output_padding=output_padding,
)
out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)
self.assertEqual(out_pt.shape, out_mx.shape)
self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))

        for dtype in ("float32",):
for N, C, O in ((1, 1, 1), (1, 6, 1), (4, 32, 64)):
for idim, kdim, stride, padding, output_padding in (
((3, 3), (2, 2), (2, 2), (0, 0), (1, 1)),
((5, 5), (3, 3), (2, 2), (1, 1), (0, 0)),
((7, 7), (4, 4), (3, 3), (1, 1), (2, 2)),
):
run_conv_transpose_2d_output_padding(
N,
C,
O,
idim,
kdim,
stride,
padding,
output_padding,
dtype=dtype,
)

    @unittest.skipIf(not has_torch, "requires Torch")
def test_torch_conv_transpose_3d_output_padding(self):
def run_conv_transpose_3d_output_padding(
N,
C,
O,
idim,
kdim,
stride,
padding,
output_padding,
dtype="float32",
atol=1e-5,
):
with self.subTest(
dtype=dtype,
N=N,
C=C,
O=O,
idim=idim,
kdim=kdim,
stride=stride,
padding=padding,
output_padding=output_padding,
):
np_dtype = getattr(np, dtype)
np.random.seed(0)
iD, iH, iW = idim
kD, kH, kW = kdim
in_np = np.random.normal(0, 1.0 / C, (N, iD, iH, iW, C)).astype(
np_dtype
)
wt_np = np.random.normal(0, 1.0 / C, (O, kD, kH, kW, C)).astype(
np_dtype
)
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
in_pt = torch.from_numpy(in_np.transpose(0, 4, 1, 2, 3))
wt_pt = torch.from_numpy(wt_np.transpose(4, 0, 1, 2, 3))
out_mx = mx.conv_transpose3d(
in_mx,
wt_mx,
stride=stride,
padding=padding,
output_padding=output_padding,
)
out_pt = torch.conv_transpose3d(
in_pt,
wt_pt,
stride=stride,
padding=padding,
output_padding=output_padding,
)
out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1)).numpy(force=True)
self.assertEqual(out_pt.shape, out_mx.shape)
self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))

        for dtype in ("float32",):
for N, C, O in ((1, 1, 1), (1, 6, 1), (4, 32, 64)):
for idim, kdim, stride, padding, output_padding in (
((3, 3, 3), (2, 2, 2), (2, 2, 2), (0, 0, 0), (1, 1, 1)),
((5, 5, 5), (3, 3, 3), (2, 2, 2), (1, 1, 1), (0, 0, 0)),
((7, 7, 7), (4, 4, 4), (3, 3, 3), (1, 1, 1), (2, 2, 2)),
):
run_conv_transpose_3d_output_padding(
N,
C,
O,
idim,
kdim,
stride,
padding,
output_padding,
dtype=dtype,
)


if __name__ == "__main__":
unittest.main()