Added output_padding to nn.layers as well as tests for the same

Repository: mlx (https://github.com/ml-explore/mlx.git)
Commit: f2b5ba49af (parent: 59b934dcf9)
@@ -1291,6 +1291,7 @@ array conv_transpose1d(
     int stride = 1,
     int padding = 0,
     int dilation = 1,
+    int output_padding = 0,
     int groups = 1,
     StreamOrDevice s = {});

@@ -1301,6 +1302,7 @@ array conv_transpose2d(
     const std::pair<int, int>& stride = {1, 1},
     const std::pair<int, int>& padding = {0, 0},
     const std::pair<int, int>& dilation = {1, 1},
+    const std::pair<int, int>& output_padding = {0, 0},
     int groups = 1,
     StreamOrDevice s = {});

@@ -1311,6 +1313,7 @@ array conv_transpose3d(
     const std::tuple<int, int, int>& stride = {1, 1, 1},
     const std::tuple<int, int, int>& padding = {0, 0, 0},
     const std::tuple<int, int, int>& dilation = {1, 1, 1},
+    const std::tuple<int, int, int>& output_padding = {0, 0, 0},
     int groups = 1,
     StreamOrDevice s = {});
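For reference, output_padding exists because a strided transposed convolution's output size is ambiguous: several forward-convolution input lengths collapse to the same output length, and output_padding picks among them by growing one side of the output. A minimal Python sketch of the usual 1-D output-length formula; the formula itself is an assumption here, following the convention PyTorch documents for conv_transpose1d:

def conv_transpose1d_output_size(in_size, kernel, stride=1, padding=0,
                                 dilation=1, output_padding=0):
    # Assumed PyTorch-style formula: each unit of output_padding adds one
    # element to one side of the output only.
    return (in_size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1

# Matches the (iH=3, kH=2, stride=2, padding=0, output_padding=1) test case below:
assert conv_transpose1d_output_size(3, 2, stride=2) == 6
assert conv_transpose1d_output_size(3, 2, stride=2, output_padding=1) == 7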
@@ -25,6 +25,8 @@ class ConvTranspose1d(Module):
         padding (int, optional): How many positions to 0-pad the input with.
             Default: ``0``.
         dilation (int, optional): The dilation of the convolution.
+        output_padding(int, optional): Additional size added to one side of the output shape.
+            Default: ``0``.
         bias (bool, optional): If ``True`` add a learnable bias to the output.
             Default: ``True``
     """
@@ -37,6 +39,7 @@ class ConvTranspose1d(Module):
         stride: int = 1,
         padding: int = 0,
         dilation: int = 1,
+        output_padding: int = 0,
         bias: bool = True,
     ):
         super().__init__()
@@ -53,18 +56,25 @@ class ConvTranspose1d(Module):
         self.padding = padding
         self.dilation = dilation
         self.stride = stride
+        self.output_padding = output_padding

     def _extra_repr(self):
         return (
             f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
             f"kernel_size={self.weight.shape[1]}, stride={self.stride}, "
             f"padding={self.padding}, dilation={self.dilation}, "
+            f"output_padding={self.output_padding}, "
             f"bias={'bias' in self}"
         )

     def __call__(self, x):
         y = mx.conv_transpose1d(
-            x, self.weight, self.stride, self.padding, self.dilation
+            x,
+            self.weight,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.output_padding,
         )
         if "bias" in self:
             y = y + self.bias
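To see the new keyword end to end, a small hypothetical usage sketch (MLX is channels-last, so the input is (N, L, C_in); the expected length follows the formula sketched above):

import mlx.core as mx
import mlx.nn as nn

x = mx.random.normal((1, 3, 8))  # (N, L, C_in), channels-last
layer = nn.ConvTranspose1d(8, 16, kernel_size=2, stride=2, output_padding=1)
print(layer(x).shape)  # expected (1, 7, 16): (3 - 1) * 2 + (2 - 1) + 1 + 1 = 7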
@@ -90,6 +100,8 @@ class ConvTranspose2d(Module):
         padding (int or tuple, optional): How many positions to 0-pad
             the input with. Default: ``0``.
         dilation (int or tuple, optional): The dilation of the convolution.
+        output_padding(int or tuple, optional): Additional size added to one side of the output shape.
+            Default: ``0``.
         bias (bool, optional): If ``True`` add a learnable bias to the
             output. Default: ``True``
     """
@@ -102,13 +114,14 @@ class ConvTranspose2d(Module):
         stride: Union[int, tuple] = 1,
         padding: Union[int, tuple] = 0,
         dilation: Union[int, tuple] = 1,
+        output_padding: Union[int, tuple] = 0,
         bias: bool = True,
     ):
         super().__init__()

-        kernel_size, stride, padding = map(
+        kernel_size, stride, padding, output_padding = map(
             lambda x: (x, x) if isinstance(x, int) else x,
-            (kernel_size, stride, padding),
+            (kernel_size, stride, padding, output_padding),
         )
         scale = math.sqrt(1 / (in_channels * kernel_size[0] * kernel_size[1]))
         self.weight = mx.random.uniform(
@@ -122,18 +135,25 @@ class ConvTranspose2d(Module):
         self.padding = padding
         self.stride = stride
         self.dilation = dilation
+        self.output_padding = output_padding

     def _extra_repr(self):
         return (
             f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
             f"kernel_size={self.weight.shape[1:2]}, stride={self.stride}, "
             f"padding={self.padding}, dilation={self.dilation}, "
+            f"output_padding={self.output_padding}, "
             f"bias={'bias' in self}"
         )

     def __call__(self, x):
         y = mx.conv_transpose2d(
-            x, self.weight, self.stride, self.padding, self.dilation
+            x,
+            self.weight,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.output_padding,
         )
         if "bias" in self:
             y = y + self.bias
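The 2-D layer applies the same rule per spatial axis; another hypothetical sketch, with a (N, H, W, C_in) input:

x = mx.random.normal((1, 5, 5, 4))  # (N, H, W, C_in), channels-last
layer = nn.ConvTranspose2d(4, 8, kernel_size=3, stride=2, padding=1, output_padding=1)
print(layer(x).shape)  # expected (1, 10, 10, 8): (5 - 1) * 2 - 2 + (3 - 1) + 1 + 1 = 10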
@@ -160,6 +180,8 @@ class ConvTranspose3d(Module):
         padding (int or tuple, optional): How many positions to 0-pad
             the input with. Default: ``0``.
         dilation (int or tuple, optional): The dilation of the convolution.
+        output_padding(int or tuple, optional): Additional size added to one side of the output shape.
+            Default: ``0``.
         bias (bool, optional): If ``True`` add a learnable bias to the
             output. Default: ``True``
     """
@@ -172,13 +194,14 @@ class ConvTranspose3d(Module):
         stride: Union[int, tuple] = 1,
         padding: Union[int, tuple] = 0,
         dilation: Union[int, tuple] = 1,
+        output_padding: Union[int, tuple] = 0,
         bias: bool = True,
     ):
         super().__init__()

-        kernel_size, stride, padding = map(
+        kernel_size, stride, padding, output_padding = map(
             lambda x: (x, x, x) if isinstance(x, int) else x,
-            (kernel_size, stride, padding),
+            (kernel_size, stride, padding, output_padding),
         )
         scale = math.sqrt(
             1 / (in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2])
@@ -194,18 +217,25 @@ class ConvTranspose3d(Module):
         self.padding = padding
         self.stride = stride
         self.dilation = dilation
+        self.output_padding = output_padding

     def _extra_repr(self):
         return (
             f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
             f"kernel_size={self.weight.shape[1:3]}, stride={self.stride}, "
             f"padding={self.padding}, dilation={self.dilation}, "
+            f"output_padding={self.output_padding}, "
             f"bias={'bias' in self}"
         )

     def __call__(self, x):
         y = mx.conv_transpose3d(
-            x, self.weight, self.stride, self.padding, self.dilation
+            x,
+            self.weight,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.output_padding,
         )
         if "bias" in self:
             y = y + self.bias
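And the 3-D counterpart, hypothetical as well, with a (N, D, H, W, C_in) input:

x = mx.random.normal((1, 3, 3, 3, 2))  # (N, D, H, W, C_in), channels-last
layer = nn.ConvTranspose3d(2, 4, kernel_size=2, stride=2, output_padding=1)
print(layer(x).shape)  # expected (1, 7, 7, 7, 4): (3 - 1) * 2 + (2 - 1) + 1 + 1 = 7 per axis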
@@ -596,6 +596,215 @@ class TestConvTranspose(mlx_tests.MLXTestCase):
                         N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
                     )

+    @unittest.skipIf(not has_torch, "requires Torch")
+    def test_torch_conv_tranpose_1d_output_padding(self):
+        def run_conv_transpose_1d_output_padding(
+            N, C, O, iH, kH, stride, padding, output_padding, dtype="float32", atol=1e-5
+        ):
+            with self.subTest(
+                dtype=dtype,
+                N=N,
+                C=C,
+                O=O,
+                iH=iH,
+                kH=kH,
+                stride=stride,
+                padding=padding,
+                output_padding=output_padding,
+            ):
+                np_dtype = getattr(np, dtype)
+                np.random.seed(0)
+                in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)
+                wt_np = np.random.normal(0, 1.0 / C, (O, kH, C)).astype(np_dtype)
+
+                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
+                in_pt = torch.from_numpy(in_np.transpose(0, 2, 1))
+                wt_pt = torch.from_numpy(wt_np.transpose(2, 0, 1))
+
+                out_mx = mx.conv_transpose1d(
+                    in_mx,
+                    wt_mx,
+                    stride=stride,
+                    padding=padding,
+                    output_padding=output_padding,
+                )
+
+                out_pt = torch.conv_transpose1d(
+                    in_pt,
+                    wt_pt,
+                    stride=stride,
+                    padding=padding,
+                    output_padding=output_padding,
+                )
+                out_pt = torch.transpose(out_pt, 2, 1)
+
+                self.assertEqual(out_pt.shape, out_mx.shape)
+                self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=atol))
+
+        for dtype in ("float32",):
+            for N, C, O in ((1, 1, 1), (1, 6, 1), (4, 32, 64)):
+                for iH, kH, stride, padding, output_padding in (
+                    (3, 2, 2, 0, 1),
+                    (5, 3, 2, 1, 0),
+                    (7, 4, 3, 1, 2),
+                ):
+                    run_conv_transpose_1d_output_padding(
+                        N, C, O, iH, kH, stride, padding, output_padding, dtype=dtype
+                    )
+
+    @unittest.skipIf(not has_torch, "requires Torch")
+    def test_torch_conv_transpose_2d_output_padding(self):
+        def run_conv_transpose_2d_output_padding(
+            N,
+            C,
+            O,
+            idim,
+            kdim,
+            stride,
+            padding,
+            output_padding,
+            dtype="float32",
+            atol=1e-5,
+        ):
+            with self.subTest(
+                dtype=dtype,
+                N=N,
+                C=C,
+                O=O,
+                idim=idim,
+                kdim=kdim,
+                stride=stride,
+                padding=padding,
+                output_padding=output_padding,
+            ):
+                np_dtype = getattr(np, dtype)
+                np.random.seed(0)
+                iH, iW = idim
+                kH, kW = kdim
+                in_np = np.random.normal(0, 1.0 / C, (N, iH, iW, C)).astype(np_dtype)
+                wt_np = np.random.normal(0, 1.0 / C, (O, kH, kW, C)).astype(np_dtype)
+
+                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
+                in_pt = torch.from_numpy(in_np.transpose(0, 3, 1, 2))
+                wt_pt = torch.from_numpy(wt_np.transpose(3, 0, 1, 2))
+
+                out_mx = mx.conv_transpose2d(
+                    in_mx,
+                    wt_mx,
+                    stride=stride,
+                    padding=padding,
+                    output_padding=output_padding,
+                )
+
+                out_pt = torch.conv_transpose2d(
+                    in_pt,
+                    wt_pt,
+                    stride=stride,
+                    padding=padding,
+                    output_padding=output_padding,
+                )
+                out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)
+
+                self.assertEqual(out_pt.shape, out_mx.shape)
+                self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))
+
+        for dtype in ("float32",):
+            for N, C, O in ((1, 1, 1), (1, 6, 1), (4, 32, 64)):
+                for idim, kdim, stride, padding, output_padding in (
+                    ((3, 3), (2, 2), (2, 2), (0, 0), (1, 1)),
+                    ((5, 5), (3, 3), (2, 2), (1, 1), (0, 0)),
+                    ((7, 7), (4, 4), (3, 3), (1, 1), (2, 2)),
+                ):
+                    run_conv_transpose_2d_output_padding(
+                        N,
+                        C,
+                        O,
+                        idim,
+                        kdim,
+                        stride,
+                        padding,
+                        output_padding,
+                        dtype=dtype,
+                    )
+
+    @unittest.skipIf(not has_torch, "requires Torch")
+    def test_torch_conv_transpose_3d_output_padding(self):
+        def run_conv_transpose_3d_output_padding(
+            N,
+            C,
+            O,
+            idim,
+            kdim,
+            stride,
+            padding,
+            output_padding,
+            dtype="float32",
+            atol=1e-5,
+        ):
+            with self.subTest(
+                dtype=dtype,
+                N=N,
+                C=C,
+                O=O,
+                idim=idim,
+                kdim=kdim,
+                stride=stride,
+                padding=padding,
+                output_padding=output_padding,
+            ):
+                np_dtype = getattr(np, dtype)
+                np.random.seed(0)
+                iD, iH, iW = idim
+                kD, kH, kW = kdim
+                in_np = np.random.normal(0, 1.0 / C, (N, iD, iH, iW, C)).astype(
+                    np_dtype
+                )
+                wt_np = np.random.normal(0, 1.0 / C, (O, kD, kH, kW, C)).astype(
+                    np_dtype
+                )
+
+                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
+                in_pt = torch.from_numpy(in_np.transpose(0, 4, 1, 2, 3))
+                wt_pt = torch.from_numpy(wt_np.transpose(4, 0, 1, 2, 3))
+
+                out_mx = mx.conv_transpose3d(
+                    in_mx,
+                    wt_mx,
+                    stride=stride,
+                    padding=padding,
+                    output_padding=output_padding,
+                )
+                out_pt = torch.conv_transpose3d(
+                    in_pt,
+                    wt_pt,
+                    stride=stride,
+                    padding=padding,
+                    output_padding=output_padding,
+                )
+                out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1)).numpy(force=True)
+
+                self.assertEqual(out_pt.shape, out_mx.shape)
+                self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))
+
+        for dtype in ("float32",):
+            for N, C, O in ((1, 1, 1), (1, 6, 1), (4, 32, 64)):
+                for idim, kdim, stride, padding, output_padding in (
+                    ((3, 3, 3), (2, 2, 2), (2, 2, 2), (0, 0, 0), (1, 1, 1)),
+                    ((5, 5, 5), (3, 3, 3), (2, 2, 2), (1, 1, 1), (0, 0, 0)),
+                    ((7, 7, 7), (4, 4, 4), (3, 3, 3), (1, 1, 1), (2, 2, 2)),
+                ):
+                    run_conv_transpose_3d_output_padding(
+                        N,
+                        C,
+                        O,
+                        idim,
+                        kdim,
+                        stride,
+                        padding,
+                        output_padding,
+                        dtype=dtype,
+                    )
+

 if __name__ == "__main__":
     unittest.main()
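A note on the tests: MLX uses channels-last layouts while PyTorch uses channels-first, so the parity checks transpose both the activations and the weights before comparing. A minimal NumPy sketch of the 1-D correspondence the tests rely on:

import numpy as np

x_mlx = np.zeros((2, 5, 3))      # MLX activations: (N, L, C_in)
x_pt = x_mlx.transpose(0, 2, 1)  # PyTorch activations: (N, C_in, L)
w_mlx = np.zeros((4, 2, 3))      # MLX weights: (C_out, K, C_in)
w_pt = w_mlx.transpose(2, 0, 1)  # PyTorch conv_transpose1d weights: (C_in, C_out, K)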