mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Added output_padding parameters in conv_transpose (#2092)
This commit is contained in:
parent
3836445241
commit
600e87e03c
13
mlx/ops.cpp
13
mlx/ops.cpp
@ -3769,6 +3769,7 @@ array conv_transpose_general(
|
|||||||
std::vector<int> stride,
|
std::vector<int> stride,
|
||||||
std::vector<int> padding,
|
std::vector<int> padding,
|
||||||
std::vector<int> dilation,
|
std::vector<int> dilation,
|
||||||
|
std::vector<int> output_padding,
|
||||||
int groups,
|
int groups,
|
||||||
StreamOrDevice s) {
|
StreamOrDevice s) {
|
||||||
std::vector<int> padding_lo(padding.size());
|
std::vector<int> padding_lo(padding.size());
|
||||||
@ -3782,7 +3783,8 @@ array conv_transpose_general(
|
|||||||
|
|
||||||
int in_size = 1 + (conv_output_shape - 1);
|
int in_size = 1 + (conv_output_shape - 1);
|
||||||
int out_size = 1 + stride[i] * (input.shape(1 + i) - 1);
|
int out_size = 1 + stride[i] * (input.shape(1 + i) - 1);
|
||||||
padding_hi[i] = in_size - out_size + padding[i];
|
padding_hi[i] = in_size - out_size + padding[i] +
|
||||||
|
output_padding[i]; // Adjust with output_padding
|
||||||
}
|
}
|
||||||
|
|
||||||
return conv_general(
|
return conv_general(
|
||||||
@ -3805,10 +3807,11 @@ array conv_transpose1d(
|
|||||||
int stride /* = 1 */,
|
int stride /* = 1 */,
|
||||||
int padding /* = 0 */,
|
int padding /* = 0 */,
|
||||||
int dilation /* = 1 */,
|
int dilation /* = 1 */,
|
||||||
|
int output_padding /* = 0 */,
|
||||||
int groups /* = 1 */,
|
int groups /* = 1 */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
return conv_transpose_general(
|
return conv_transpose_general(
|
||||||
in_, wt_, {stride}, {padding}, {dilation}, groups, s);
|
in_, wt_, {stride}, {padding}, {dilation}, {output_padding}, groups, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** 2D transposed convolution with a filter */
|
/** 2D transposed convolution with a filter */
|
||||||
@ -3818,6 +3821,7 @@ array conv_transpose2d(
|
|||||||
const std::pair<int, int>& stride /* = {1, 1} */,
|
const std::pair<int, int>& stride /* = {1, 1} */,
|
||||||
const std::pair<int, int>& padding /* = {0, 0} */,
|
const std::pair<int, int>& padding /* = {0, 0} */,
|
||||||
const std::pair<int, int>& dilation /* = {1, 1} */,
|
const std::pair<int, int>& dilation /* = {1, 1} */,
|
||||||
|
const std::pair<int, int>& output_padding /* = {0, 0} */,
|
||||||
int groups /* = 1 */,
|
int groups /* = 1 */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
return conv_transpose_general(
|
return conv_transpose_general(
|
||||||
@ -3826,6 +3830,7 @@ array conv_transpose2d(
|
|||||||
{stride.first, stride.second},
|
{stride.first, stride.second},
|
||||||
{padding.first, padding.second},
|
{padding.first, padding.second},
|
||||||
{dilation.first, dilation.second},
|
{dilation.first, dilation.second},
|
||||||
|
{output_padding.first, output_padding.second},
|
||||||
groups,
|
groups,
|
||||||
s);
|
s);
|
||||||
}
|
}
|
||||||
@ -3837,6 +3842,7 @@ array conv_transpose3d(
|
|||||||
const std::tuple<int, int, int>& stride /* = {1, 1, 1} */,
|
const std::tuple<int, int, int>& stride /* = {1, 1, 1} */,
|
||||||
const std::tuple<int, int, int>& padding /* = {0, 0, 0} */,
|
const std::tuple<int, int, int>& padding /* = {0, 0, 0} */,
|
||||||
const std::tuple<int, int, int>& dilation /* = {1, 1, 1} */,
|
const std::tuple<int, int, int>& dilation /* = {1, 1, 1} */,
|
||||||
|
const std::tuple<int, int, int>& output_padding /* = {0, 0, 0} */,
|
||||||
int groups /* = 1 */,
|
int groups /* = 1 */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
return conv_transpose_general(
|
return conv_transpose_general(
|
||||||
@ -3845,6 +3851,9 @@ array conv_transpose3d(
|
|||||||
{std::get<0>(stride), std::get<1>(stride), std::get<2>(stride)},
|
{std::get<0>(stride), std::get<1>(stride), std::get<2>(stride)},
|
||||||
{std::get<0>(padding), std::get<1>(padding), std::get<2>(padding)},
|
{std::get<0>(padding), std::get<1>(padding), std::get<2>(padding)},
|
||||||
{std::get<0>(dilation), std::get<1>(dilation), std::get<2>(dilation)},
|
{std::get<0>(dilation), std::get<1>(dilation), std::get<2>(dilation)},
|
||||||
|
{std::get<0>(output_padding),
|
||||||
|
std::get<1>(output_padding),
|
||||||
|
std::get<2>(output_padding)},
|
||||||
groups,
|
groups,
|
||||||
s);
|
s);
|
||||||
}
|
}
|
||||||
|
@ -1291,6 +1291,7 @@ array conv_transpose1d(
|
|||||||
int stride = 1,
|
int stride = 1,
|
||||||
int padding = 0,
|
int padding = 0,
|
||||||
int dilation = 1,
|
int dilation = 1,
|
||||||
|
int output_padding = 0,
|
||||||
int groups = 1,
|
int groups = 1,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
@ -1301,6 +1302,7 @@ array conv_transpose2d(
|
|||||||
const std::pair<int, int>& stride = {1, 1},
|
const std::pair<int, int>& stride = {1, 1},
|
||||||
const std::pair<int, int>& padding = {0, 0},
|
const std::pair<int, int>& padding = {0, 0},
|
||||||
const std::pair<int, int>& dilation = {1, 1},
|
const std::pair<int, int>& dilation = {1, 1},
|
||||||
|
const std::pair<int, int>& output_padding = {0, 0},
|
||||||
int groups = 1,
|
int groups = 1,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
@ -1311,6 +1313,7 @@ array conv_transpose3d(
|
|||||||
const std::tuple<int, int, int>& stride = {1, 1, 1},
|
const std::tuple<int, int, int>& stride = {1, 1, 1},
|
||||||
const std::tuple<int, int, int>& padding = {0, 0, 0},
|
const std::tuple<int, int, int>& padding = {0, 0, 0},
|
||||||
const std::tuple<int, int, int>& dilation = {1, 1, 1},
|
const std::tuple<int, int, int>& dilation = {1, 1, 1},
|
||||||
|
const std::tuple<int, int, int>& output_padding = {0, 0, 0},
|
||||||
int groups = 1,
|
int groups = 1,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
@ -25,6 +25,8 @@ class ConvTranspose1d(Module):
|
|||||||
padding (int, optional): How many positions to 0-pad the input with.
|
padding (int, optional): How many positions to 0-pad the input with.
|
||||||
Default: ``0``.
|
Default: ``0``.
|
||||||
dilation (int, optional): The dilation of the convolution.
|
dilation (int, optional): The dilation of the convolution.
|
||||||
|
output_padding(int, optional): Additional size added to one side of the
|
||||||
|
output shape. Default: ``0``.
|
||||||
bias (bool, optional): If ``True`` add a learnable bias to the output.
|
bias (bool, optional): If ``True`` add a learnable bias to the output.
|
||||||
Default: ``True``
|
Default: ``True``
|
||||||
"""
|
"""
|
||||||
@ -37,6 +39,7 @@ class ConvTranspose1d(Module):
|
|||||||
stride: int = 1,
|
stride: int = 1,
|
||||||
padding: int = 0,
|
padding: int = 0,
|
||||||
dilation: int = 1,
|
dilation: int = 1,
|
||||||
|
output_padding: int = 0,
|
||||||
bias: bool = True,
|
bias: bool = True,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -53,18 +56,25 @@ class ConvTranspose1d(Module):
|
|||||||
self.padding = padding
|
self.padding = padding
|
||||||
self.dilation = dilation
|
self.dilation = dilation
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
|
self.output_padding = output_padding
|
||||||
|
|
||||||
def _extra_repr(self):
|
def _extra_repr(self):
|
||||||
return (
|
return (
|
||||||
f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
|
f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
|
||||||
f"kernel_size={self.weight.shape[1]}, stride={self.stride}, "
|
f"kernel_size={self.weight.shape[1]}, stride={self.stride}, "
|
||||||
f"padding={self.padding}, dilation={self.dilation}, "
|
f"padding={self.padding}, dilation={self.dilation}, "
|
||||||
|
f"output_padding={self.output_padding}, "
|
||||||
f"bias={'bias' in self}"
|
f"bias={'bias' in self}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
y = mx.conv_transpose1d(
|
y = mx.conv_transpose1d(
|
||||||
x, self.weight, self.stride, self.padding, self.dilation
|
x,
|
||||||
|
self.weight,
|
||||||
|
self.stride,
|
||||||
|
self.padding,
|
||||||
|
self.dilation,
|
||||||
|
self.output_padding,
|
||||||
)
|
)
|
||||||
if "bias" in self:
|
if "bias" in self:
|
||||||
y = y + self.bias
|
y = y + self.bias
|
||||||
@ -90,6 +100,8 @@ class ConvTranspose2d(Module):
|
|||||||
padding (int or tuple, optional): How many positions to 0-pad
|
padding (int or tuple, optional): How many positions to 0-pad
|
||||||
the input with. Default: ``0``.
|
the input with. Default: ``0``.
|
||||||
dilation (int or tuple, optional): The dilation of the convolution.
|
dilation (int or tuple, optional): The dilation of the convolution.
|
||||||
|
output_padding(int or tuple, optional): Additional size added to one
|
||||||
|
side of the output shape. Default: ``0``.
|
||||||
bias (bool, optional): If ``True`` add a learnable bias to the
|
bias (bool, optional): If ``True`` add a learnable bias to the
|
||||||
output. Default: ``True``
|
output. Default: ``True``
|
||||||
"""
|
"""
|
||||||
@ -102,13 +114,14 @@ class ConvTranspose2d(Module):
|
|||||||
stride: Union[int, tuple] = 1,
|
stride: Union[int, tuple] = 1,
|
||||||
padding: Union[int, tuple] = 0,
|
padding: Union[int, tuple] = 0,
|
||||||
dilation: Union[int, tuple] = 1,
|
dilation: Union[int, tuple] = 1,
|
||||||
|
output_padding: Union[int, tuple] = 0,
|
||||||
bias: bool = True,
|
bias: bool = True,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
kernel_size, stride, padding = map(
|
kernel_size, stride, padding, output_padding = map(
|
||||||
lambda x: (x, x) if isinstance(x, int) else x,
|
lambda x: (x, x) if isinstance(x, int) else x,
|
||||||
(kernel_size, stride, padding),
|
(kernel_size, stride, padding, output_padding),
|
||||||
)
|
)
|
||||||
scale = math.sqrt(1 / (in_channels * kernel_size[0] * kernel_size[1]))
|
scale = math.sqrt(1 / (in_channels * kernel_size[0] * kernel_size[1]))
|
||||||
self.weight = mx.random.uniform(
|
self.weight = mx.random.uniform(
|
||||||
@ -122,18 +135,25 @@ class ConvTranspose2d(Module):
|
|||||||
self.padding = padding
|
self.padding = padding
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
self.dilation = dilation
|
self.dilation = dilation
|
||||||
|
self.output_padding = output_padding
|
||||||
|
|
||||||
def _extra_repr(self):
|
def _extra_repr(self):
|
||||||
return (
|
return (
|
||||||
f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
|
f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
|
||||||
f"kernel_size={self.weight.shape[1:2]}, stride={self.stride}, "
|
f"kernel_size={self.weight.shape[1:2]}, stride={self.stride}, "
|
||||||
f"padding={self.padding}, dilation={self.dilation}, "
|
f"padding={self.padding}, dilation={self.dilation}, "
|
||||||
|
f"output_padding={self.output_padding}, "
|
||||||
f"bias={'bias' in self}"
|
f"bias={'bias' in self}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
y = mx.conv_transpose2d(
|
y = mx.conv_transpose2d(
|
||||||
x, self.weight, self.stride, self.padding, self.dilation
|
x,
|
||||||
|
self.weight,
|
||||||
|
self.stride,
|
||||||
|
self.padding,
|
||||||
|
self.dilation,
|
||||||
|
self.output_padding,
|
||||||
)
|
)
|
||||||
if "bias" in self:
|
if "bias" in self:
|
||||||
y = y + self.bias
|
y = y + self.bias
|
||||||
@ -160,6 +180,8 @@ class ConvTranspose3d(Module):
|
|||||||
padding (int or tuple, optional): How many positions to 0-pad
|
padding (int or tuple, optional): How many positions to 0-pad
|
||||||
the input with. Default: ``0``.
|
the input with. Default: ``0``.
|
||||||
dilation (int or tuple, optional): The dilation of the convolution.
|
dilation (int or tuple, optional): The dilation of the convolution.
|
||||||
|
output_padding(int or tuple, optional): Additional size added to one
|
||||||
|
side of the output shape. Default: ``0``.
|
||||||
bias (bool, optional): If ``True`` add a learnable bias to the
|
bias (bool, optional): If ``True`` add a learnable bias to the
|
||||||
output. Default: ``True``
|
output. Default: ``True``
|
||||||
"""
|
"""
|
||||||
@ -172,13 +194,14 @@ class ConvTranspose3d(Module):
|
|||||||
stride: Union[int, tuple] = 1,
|
stride: Union[int, tuple] = 1,
|
||||||
padding: Union[int, tuple] = 0,
|
padding: Union[int, tuple] = 0,
|
||||||
dilation: Union[int, tuple] = 1,
|
dilation: Union[int, tuple] = 1,
|
||||||
|
output_padding: Union[int, tuple] = 0,
|
||||||
bias: bool = True,
|
bias: bool = True,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
kernel_size, stride, padding = map(
|
kernel_size, stride, padding, output_padding = map(
|
||||||
lambda x: (x, x, x) if isinstance(x, int) else x,
|
lambda x: (x, x, x) if isinstance(x, int) else x,
|
||||||
(kernel_size, stride, padding),
|
(kernel_size, stride, padding, output_padding),
|
||||||
)
|
)
|
||||||
scale = math.sqrt(
|
scale = math.sqrt(
|
||||||
1 / (in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2])
|
1 / (in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2])
|
||||||
@ -194,18 +217,25 @@ class ConvTranspose3d(Module):
|
|||||||
self.padding = padding
|
self.padding = padding
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
self.dilation = dilation
|
self.dilation = dilation
|
||||||
|
self.output_padding = output_padding
|
||||||
|
|
||||||
def _extra_repr(self):
|
def _extra_repr(self):
|
||||||
return (
|
return (
|
||||||
f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
|
f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
|
||||||
f"kernel_size={self.weight.shape[1:3]}, stride={self.stride}, "
|
f"kernel_size={self.weight.shape[1:3]}, stride={self.stride}, "
|
||||||
f"padding={self.padding}, dilation={self.dilation}, "
|
f"padding={self.padding}, dilation={self.dilation}, "
|
||||||
|
f"output_padding={self.output_padding}, "
|
||||||
f"bias={'bias' in self}"
|
f"bias={'bias' in self}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
y = mx.conv_transpose3d(
|
y = mx.conv_transpose3d(
|
||||||
x, self.weight, self.stride, self.padding, self.dilation
|
x,
|
||||||
|
self.weight,
|
||||||
|
self.stride,
|
||||||
|
self.padding,
|
||||||
|
self.dilation,
|
||||||
|
self.output_padding,
|
||||||
)
|
)
|
||||||
if "bias" in self:
|
if "bias" in self:
|
||||||
y = y + self.bias
|
y = y + self.bias
|
||||||
|
@ -3609,11 +3609,12 @@ void init_ops(nb::module_& m) {
|
|||||||
"stride"_a = 1,
|
"stride"_a = 1,
|
||||||
"padding"_a = 0,
|
"padding"_a = 0,
|
||||||
"dilation"_a = 1,
|
"dilation"_a = 1,
|
||||||
|
"output_padding"_a = 0,
|
||||||
"groups"_a = 1,
|
"groups"_a = 1,
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
nb::sig(
|
nb::sig(
|
||||||
"def conv_transpose1d(input: array, weight: array, /, stride: int = 1, padding: int = 0, dilation: int = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
|
"def conv_transpose1d(input: array, weight: array, /, stride: int = 1, padding: int = 0, dilation: int = 1, output_padding: int = 0, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
1D transposed convolution over an input with several channels
|
1D transposed convolution over an input with several channels
|
||||||
|
|
||||||
@ -3623,6 +3624,7 @@ void init_ops(nb::module_& m) {
|
|||||||
stride (int, optional): Kernel stride. Default: ``1``.
|
stride (int, optional): Kernel stride. Default: ``1``.
|
||||||
padding (int, optional): Input padding. Default: ``0``.
|
padding (int, optional): Input padding. Default: ``0``.
|
||||||
dilation (int, optional): Kernel dilation. Default: ``1``.
|
dilation (int, optional): Kernel dilation. Default: ``1``.
|
||||||
|
output_padding (int, optional): Output padding. Default: ``0``.
|
||||||
groups (int, optional): Input feature groups. Default: ``1``.
|
groups (int, optional): Input feature groups. Default: ``1``.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -3635,11 +3637,13 @@ void init_ops(nb::module_& m) {
|
|||||||
const std::variant<int, std::pair<int, int>>& stride,
|
const std::variant<int, std::pair<int, int>>& stride,
|
||||||
const std::variant<int, std::pair<int, int>>& padding,
|
const std::variant<int, std::pair<int, int>>& padding,
|
||||||
const std::variant<int, std::pair<int, int>>& dilation,
|
const std::variant<int, std::pair<int, int>>& dilation,
|
||||||
|
const std::variant<int, std::pair<int, int>>& output_padding,
|
||||||
int groups,
|
int groups,
|
||||||
mx::StreamOrDevice s) {
|
mx::StreamOrDevice s) {
|
||||||
std::pair<int, int> stride_pair{1, 1};
|
std::pair<int, int> stride_pair{1, 1};
|
||||||
std::pair<int, int> padding_pair{0, 0};
|
std::pair<int, int> padding_pair{0, 0};
|
||||||
std::pair<int, int> dilation_pair{1, 1};
|
std::pair<int, int> dilation_pair{1, 1};
|
||||||
|
std::pair<int, int> output_padding_pair{0, 0};
|
||||||
|
|
||||||
if (auto pv = std::get_if<int>(&stride); pv) {
|
if (auto pv = std::get_if<int>(&stride); pv) {
|
||||||
stride_pair = std::pair<int, int>{*pv, *pv};
|
stride_pair = std::pair<int, int>{*pv, *pv};
|
||||||
@ -3659,19 +3663,33 @@ void init_ops(nb::module_& m) {
|
|||||||
dilation_pair = std::get<std::pair<int, int>>(dilation);
|
dilation_pair = std::get<std::pair<int, int>>(dilation);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (auto pv = std::get_if<int>(&output_padding); pv) {
|
||||||
|
output_padding_pair = std::pair<int, int>{*pv, *pv};
|
||||||
|
} else {
|
||||||
|
output_padding_pair = std::get<std::pair<int, int>>(output_padding);
|
||||||
|
}
|
||||||
|
|
||||||
return mx::conv_transpose2d(
|
return mx::conv_transpose2d(
|
||||||
input, weight, stride_pair, padding_pair, dilation_pair, groups, s);
|
input,
|
||||||
|
weight,
|
||||||
|
stride_pair,
|
||||||
|
padding_pair,
|
||||||
|
dilation_pair,
|
||||||
|
output_padding_pair,
|
||||||
|
groups,
|
||||||
|
s);
|
||||||
},
|
},
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
"stride"_a = 1,
|
"stride"_a = 1,
|
||||||
"padding"_a = 0,
|
"padding"_a = 0,
|
||||||
"dilation"_a = 1,
|
"dilation"_a = 1,
|
||||||
|
"output_padding"_a = 0,
|
||||||
"groups"_a = 1,
|
"groups"_a = 1,
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
nb::sig(
|
nb::sig(
|
||||||
"def conv_transpose2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
|
"def conv_transpose2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, output_padding: Union[int, Tuple[int, int]] = 0, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
2D transposed convolution over an input with several channels
|
2D transposed convolution over an input with several channels
|
||||||
|
|
||||||
@ -3689,6 +3707,9 @@ void init_ops(nb::module_& m) {
|
|||||||
dilation (int or tuple(int), optional): :obj:`tuple` of size 2 with
|
dilation (int or tuple(int), optional): :obj:`tuple` of size 2 with
|
||||||
kernel dilation. All spatial dimensions get the same dilation
|
kernel dilation. All spatial dimensions get the same dilation
|
||||||
if only one number is specified. Default: ``1``
|
if only one number is specified. Default: ``1``
|
||||||
|
output_padding (int or tuple(int), optional): :obj:`tuple` of size 2 with
|
||||||
|
output padding. All spatial dimensions get the same output
|
||||||
|
padding if only one number is specified. Default: ``0``.
|
||||||
groups (int, optional): input feature groups. Default: ``1``.
|
groups (int, optional): input feature groups. Default: ``1``.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -3701,11 +3722,13 @@ void init_ops(nb::module_& m) {
|
|||||||
const std::variant<int, std::tuple<int, int, int>>& stride,
|
const std::variant<int, std::tuple<int, int, int>>& stride,
|
||||||
const std::variant<int, std::tuple<int, int, int>>& padding,
|
const std::variant<int, std::tuple<int, int, int>>& padding,
|
||||||
const std::variant<int, std::tuple<int, int, int>>& dilation,
|
const std::variant<int, std::tuple<int, int, int>>& dilation,
|
||||||
|
const std::variant<int, std::tuple<int, int, int>>& output_padding,
|
||||||
int groups,
|
int groups,
|
||||||
mx::StreamOrDevice s) {
|
mx::StreamOrDevice s) {
|
||||||
std::tuple<int, int, int> stride_tuple{1, 1, 1};
|
std::tuple<int, int, int> stride_tuple{1, 1, 1};
|
||||||
std::tuple<int, int, int> padding_tuple{0, 0, 0};
|
std::tuple<int, int, int> padding_tuple{0, 0, 0};
|
||||||
std::tuple<int, int, int> dilation_tuple{1, 1, 1};
|
std::tuple<int, int, int> dilation_tuple{1, 1, 1};
|
||||||
|
std::tuple<int, int, int> output_padding_tuple{0, 0, 0};
|
||||||
|
|
||||||
if (auto pv = std::get_if<int>(&stride); pv) {
|
if (auto pv = std::get_if<int>(&stride); pv) {
|
||||||
stride_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};
|
stride_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};
|
||||||
@ -3725,12 +3748,20 @@ void init_ops(nb::module_& m) {
|
|||||||
dilation_tuple = std::get<std::tuple<int, int, int>>(dilation);
|
dilation_tuple = std::get<std::tuple<int, int, int>>(dilation);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (auto pv = std::get_if<int>(&output_padding); pv) {
|
||||||
|
output_padding_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};
|
||||||
|
} else {
|
||||||
|
output_padding_tuple =
|
||||||
|
std::get<std::tuple<int, int, int>>(output_padding);
|
||||||
|
}
|
||||||
|
|
||||||
return mx::conv_transpose3d(
|
return mx::conv_transpose3d(
|
||||||
input,
|
input,
|
||||||
weight,
|
weight,
|
||||||
stride_tuple,
|
stride_tuple,
|
||||||
padding_tuple,
|
padding_tuple,
|
||||||
dilation_tuple,
|
dilation_tuple,
|
||||||
|
output_padding_tuple,
|
||||||
groups,
|
groups,
|
||||||
s);
|
s);
|
||||||
},
|
},
|
||||||
@ -3739,11 +3770,12 @@ void init_ops(nb::module_& m) {
|
|||||||
"stride"_a = 1,
|
"stride"_a = 1,
|
||||||
"padding"_a = 0,
|
"padding"_a = 0,
|
||||||
"dilation"_a = 1,
|
"dilation"_a = 1,
|
||||||
|
"output_padding"_a = 0,
|
||||||
"groups"_a = 1,
|
"groups"_a = 1,
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
nb::sig(
|
nb::sig(
|
||||||
"def conv_transpose3d(input: array, weight: array, /, stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, dilation: Union[int, Tuple[int, int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
|
"def conv_transpose3d(input: array, weight: array, /, stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, dilation: Union[int, Tuple[int, int, int]] = 1, output_padding: Union[int, Tuple[int, int, int]] = 0, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
3D transposed convolution over an input with several channels
|
3D transposed convolution over an input with several channels
|
||||||
|
|
||||||
@ -3761,6 +3793,9 @@ void init_ops(nb::module_& m) {
|
|||||||
dilation (int or tuple(int), optional): :obj:`tuple` of size 3 with
|
dilation (int or tuple(int), optional): :obj:`tuple` of size 3 with
|
||||||
kernel dilation. All spatial dimensions get the same dilation
|
kernel dilation. All spatial dimensions get the same dilation
|
||||||
if only one number is specified. Default: ``1``
|
if only one number is specified. Default: ``1``
|
||||||
|
output_padding (int or tuple(int), optional): :obj:`tuple` of size 3 with
|
||||||
|
output padding. All spatial dimensions get the same output
|
||||||
|
padding if only one number is specified. Default: ``0``.
|
||||||
groups (int, optional): input feature groups. Default: ``1``.
|
groups (int, optional): input feature groups. Default: ``1``.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -596,6 +596,215 @@ class TestConvTranspose(mlx_tests.MLXTestCase):
|
|||||||
N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
|
N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@unittest.skipIf(not has_torch, "requires Torch")
|
||||||
|
def test_torch_conv_tranpose_1d_output_padding(self):
|
||||||
|
def run_conv_transpose_1d_output_padding(
|
||||||
|
N, C, O, iH, kH, stride, padding, output_padding, dtype="float32", atol=1e-5
|
||||||
|
):
|
||||||
|
with self.subTest(
|
||||||
|
dtype=dtype,
|
||||||
|
N=N,
|
||||||
|
C=C,
|
||||||
|
O=O,
|
||||||
|
iH=iH,
|
||||||
|
kH=kH,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
output_padding=output_padding,
|
||||||
|
):
|
||||||
|
np_dtype = getattr(np, dtype)
|
||||||
|
np.random.seed(0)
|
||||||
|
in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)
|
||||||
|
wt_np = np.random.normal(0, 1.0 / C, (O, kH, C)).astype(np_dtype)
|
||||||
|
|
||||||
|
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
|
||||||
|
in_pt = torch.from_numpy(in_np.transpose(0, 2, 1))
|
||||||
|
wt_pt = torch.from_numpy(wt_np.transpose(2, 0, 1))
|
||||||
|
|
||||||
|
out_mx = mx.conv_transpose1d(
|
||||||
|
in_mx,
|
||||||
|
wt_mx,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
output_padding=output_padding,
|
||||||
|
)
|
||||||
|
|
||||||
|
out_pt = torch.conv_transpose1d(
|
||||||
|
in_pt,
|
||||||
|
wt_pt,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
output_padding=output_padding,
|
||||||
|
)
|
||||||
|
out_pt = torch.transpose(out_pt, 2, 1)
|
||||||
|
|
||||||
|
self.assertEqual(out_pt.shape, out_mx.shape)
|
||||||
|
self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=atol))
|
||||||
|
|
||||||
|
for dtype in ("float32",):
|
||||||
|
for N, C, O in ((1, 1, 1), (1, 6, 1), (4, 32, 64)):
|
||||||
|
for iH, kH, stride, padding, output_padding in (
|
||||||
|
(3, 2, 2, 0, 1),
|
||||||
|
(5, 3, 2, 1, 0),
|
||||||
|
(7, 4, 3, 1, 2),
|
||||||
|
):
|
||||||
|
run_conv_transpose_1d_output_padding(
|
||||||
|
N, C, O, iH, kH, stride, padding, output_padding, dtype=dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
@unittest.skipIf(not has_torch, "requires Torch")
|
||||||
|
def test_torch_conv_transpose_2d_output_padding(self):
|
||||||
|
def run_conv_transpose_2d_output_padding(
|
||||||
|
N,
|
||||||
|
C,
|
||||||
|
O,
|
||||||
|
idim,
|
||||||
|
kdim,
|
||||||
|
stride,
|
||||||
|
padding,
|
||||||
|
output_padding,
|
||||||
|
dtype="float32",
|
||||||
|
atol=1e-5,
|
||||||
|
):
|
||||||
|
with self.subTest(
|
||||||
|
dtype=dtype,
|
||||||
|
N=N,
|
||||||
|
C=C,
|
||||||
|
O=O,
|
||||||
|
idim=idim,
|
||||||
|
kdim=kdim,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
output_padding=output_padding,
|
||||||
|
):
|
||||||
|
np_dtype = getattr(np, dtype)
|
||||||
|
np.random.seed(0)
|
||||||
|
iH, iW = idim
|
||||||
|
kH, kW = kdim
|
||||||
|
in_np = np.random.normal(0, 1.0 / C, (N, iH, iW, C)).astype(np_dtype)
|
||||||
|
wt_np = np.random.normal(0, 1.0 / C, (O, kH, kW, C)).astype(np_dtype)
|
||||||
|
|
||||||
|
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
|
||||||
|
in_pt = torch.from_numpy(in_np.transpose(0, 3, 1, 2))
|
||||||
|
wt_pt = torch.from_numpy(wt_np.transpose(3, 0, 1, 2))
|
||||||
|
|
||||||
|
out_mx = mx.conv_transpose2d(
|
||||||
|
in_mx,
|
||||||
|
wt_mx,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
output_padding=output_padding,
|
||||||
|
)
|
||||||
|
|
||||||
|
out_pt = torch.conv_transpose2d(
|
||||||
|
in_pt,
|
||||||
|
wt_pt,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
output_padding=output_padding,
|
||||||
|
)
|
||||||
|
out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)
|
||||||
|
|
||||||
|
self.assertEqual(out_pt.shape, out_mx.shape)
|
||||||
|
self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))
|
||||||
|
|
||||||
|
for dtype in ("float32",):
|
||||||
|
for N, C, O in ((1, 1, 1), (1, 6, 1), (4, 32, 64)):
|
||||||
|
for idim, kdim, stride, padding, output_padding in (
|
||||||
|
((3, 3), (2, 2), (2, 2), (0, 0), (1, 1)),
|
||||||
|
((5, 5), (3, 3), (2, 2), (1, 1), (0, 0)),
|
||||||
|
((7, 7), (4, 4), (3, 3), (1, 1), (2, 2)),
|
||||||
|
):
|
||||||
|
run_conv_transpose_2d_output_padding(
|
||||||
|
N,
|
||||||
|
C,
|
||||||
|
O,
|
||||||
|
idim,
|
||||||
|
kdim,
|
||||||
|
stride,
|
||||||
|
padding,
|
||||||
|
output_padding,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
@unittest.skipIf(not has_torch, "requires Torch")
|
||||||
|
def test_torch_conv_transpose_3d_output_padding(self):
|
||||||
|
def run_conv_transpose_3d_output_padding(
|
||||||
|
N,
|
||||||
|
C,
|
||||||
|
O,
|
||||||
|
idim,
|
||||||
|
kdim,
|
||||||
|
stride,
|
||||||
|
padding,
|
||||||
|
output_padding,
|
||||||
|
dtype="float32",
|
||||||
|
atol=1e-5,
|
||||||
|
):
|
||||||
|
with self.subTest(
|
||||||
|
dtype=dtype,
|
||||||
|
N=N,
|
||||||
|
C=C,
|
||||||
|
O=O,
|
||||||
|
idim=idim,
|
||||||
|
kdim=kdim,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
output_padding=output_padding,
|
||||||
|
):
|
||||||
|
np_dtype = getattr(np, dtype)
|
||||||
|
np.random.seed(0)
|
||||||
|
iD, iH, iW = idim
|
||||||
|
kD, kH, kW = kdim
|
||||||
|
in_np = np.random.normal(0, 1.0 / C, (N, iD, iH, iW, C)).astype(
|
||||||
|
np_dtype
|
||||||
|
)
|
||||||
|
wt_np = np.random.normal(0, 1.0 / C, (O, kD, kH, kW, C)).astype(
|
||||||
|
np_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
|
||||||
|
in_pt = torch.from_numpy(in_np.transpose(0, 4, 1, 2, 3))
|
||||||
|
wt_pt = torch.from_numpy(wt_np.transpose(4, 0, 1, 2, 3))
|
||||||
|
|
||||||
|
out_mx = mx.conv_transpose3d(
|
||||||
|
in_mx,
|
||||||
|
wt_mx,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
output_padding=output_padding,
|
||||||
|
)
|
||||||
|
out_pt = torch.conv_transpose3d(
|
||||||
|
in_pt,
|
||||||
|
wt_pt,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
output_padding=output_padding,
|
||||||
|
)
|
||||||
|
out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1)).numpy(force=True)
|
||||||
|
|
||||||
|
self.assertEqual(out_pt.shape, out_mx.shape)
|
||||||
|
self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))
|
||||||
|
|
||||||
|
for dtype in ("float32",):
|
||||||
|
for N, C, O in ((1, 1, 1), (1, 6, 1), (4, 32, 64)):
|
||||||
|
for idim, kdim, stride, padding, output_padding in (
|
||||||
|
((3, 3, 3), (2, 2, 2), (2, 2, 2), (0, 0, 0), (1, 1, 1)),
|
||||||
|
((5, 5, 5), (3, 3, 3), (2, 2, 2), (1, 1, 1), (0, 0, 0)),
|
||||||
|
((7, 7, 7), (4, 4, 4), (3, 3, 3), (1, 1, 1), (2, 2, 2)),
|
||||||
|
):
|
||||||
|
run_conv_transpose_3d_output_padding(
|
||||||
|
N,
|
||||||
|
C,
|
||||||
|
O,
|
||||||
|
idim,
|
||||||
|
kdim,
|
||||||
|
stride,
|
||||||
|
padding,
|
||||||
|
output_padding,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -3912,3 +3912,69 @@ TEST_CASE("test bitwise shift operations") {
|
|||||||
CHECK_EQ(right_shift_bool_result.dtype(), uint8);
|
CHECK_EQ(right_shift_bool_result.dtype(), uint8);
|
||||||
CHECK(array_equal(right_shift_bool_result, full({4}, 0, uint8)).item<bool>());
|
CHECK(array_equal(right_shift_bool_result, full({4}, 0, uint8)).item<bool>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test conv_transpose1d with output_padding") {
|
||||||
|
auto in = array({1.0, 2.0, 3.0}, {1, 1, 3});
|
||||||
|
auto wt = array({1.0, 1.0, 1.0}, {1, 1, 3});
|
||||||
|
int stride = 2;
|
||||||
|
int padding = 0;
|
||||||
|
int dilation = 1;
|
||||||
|
int output_padding = 1;
|
||||||
|
int groups = 1;
|
||||||
|
|
||||||
|
auto out = conv_transpose1d(
|
||||||
|
in, wt, stride, padding, dilation, output_padding, groups);
|
||||||
|
auto expected = array({6.0, 0.0}, {1, 2, 1});
|
||||||
|
CHECK(array_equal(out, expected).item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test conv_transpose2d with output_padding") {
|
||||||
|
auto in = array({1.0, 2.0, 3.0, 4.0}, {1, 1, 2, 2});
|
||||||
|
auto wt = array({1.0, 1.0, 1.0, 1.0}, {2, 1, 1, 2});
|
||||||
|
std::pair<int, int> stride{2, 2};
|
||||||
|
std::pair<int, int> padding{0, 0};
|
||||||
|
std::pair<int, int> output_padding{1, 1};
|
||||||
|
std::pair<int, int> dilation{1, 1};
|
||||||
|
int groups = 1;
|
||||||
|
|
||||||
|
auto out = conv_transpose2d(
|
||||||
|
in, wt, stride, padding, dilation, output_padding, groups);
|
||||||
|
auto expected = array(
|
||||||
|
{3.0,
|
||||||
|
3.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
7.0,
|
||||||
|
7.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0},
|
||||||
|
{1, 2, 4, 2});
|
||||||
|
CHECK(array_equal(out, expected).item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test conv_transpose3d with output_padding") {
|
||||||
|
auto in = array({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}, {1, 1, 2, 2, 2});
|
||||||
|
auto wt = array({1.0, 1.0}, {1, 1, 1, 1, 2});
|
||||||
|
std::tuple<int, int, int> stride{2, 2, 2};
|
||||||
|
std::tuple<int, int, int> padding{0, 0, 0};
|
||||||
|
std::tuple<int, int, int> output_padding{1, 1, 1};
|
||||||
|
std::tuple<int, int, int> dilation{1, 1, 1};
|
||||||
|
int groups = 1;
|
||||||
|
|
||||||
|
auto out = conv_transpose3d(
|
||||||
|
in, wt, stride, padding, dilation, output_padding, groups);
|
||||||
|
auto expected = array(
|
||||||
|
{3.0, 0.0, 7.0, 0.0, 0.0, 0.0, 0.0, 0.0, 11.0, 0.0, 15.0,
|
||||||
|
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0},
|
||||||
|
{1, 2, 4, 4, 1});
|
||||||
|
CHECK(array_equal(out, expected).item<bool>());
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user