mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-03 09:58:17 +08:00
Transposed Convolution (#1245)
* initial implementation for conv_transpose ran pre-commit implemented conv_transpose updated conv_general docstring updated conv_general docstring updated code comments removed commented run_conv_checks updated acknowledgments added missing entry to ops.rst added op to nn.layers resolved merge conflicts * removed ConvolutionTranspose primitive as suggested by reviewer removed ConvolutionTranspose primitive as suggested by reviewer * remove transpose flag, add another test --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
committed by
GitHub
parent
ba3e913c7a
commit
efeb9c0f02
@@ -3238,12 +3238,12 @@ void init_ops(nb::module_& m) {
|
||||
1D convolution over an input with several channels
|
||||
|
||||
Args:
|
||||
input (array): input array of shape (``N``, ``H``, ``C_in``)
|
||||
weight (array): weight array of shape (``C_out``, ``H``, ``C_in``)
|
||||
stride (int, optional): kernel stride. Default: ``1``.
|
||||
padding (int, optional): input padding. Default: ``0``.
|
||||
dilation (int, optional): kernel dilation. Default: ``1``.
|
||||
groups (int, optional): input feature groups. Default: ``1``.
|
||||
input (array): Input array of shape ``(N, H, C_in)``.
|
||||
weight (array): Weight array of shape ``(C_out, H, C_in)``.
|
||||
stride (int, optional): Kernel stride. Default: ``1``.
|
||||
padding (int, optional): Input padding. Default: ``0``.
|
||||
dilation (int, optional): Kernel dilation. Default: ``1``.
|
||||
groups (int, optional): Input feature groups. Default: ``1``.
|
||||
|
||||
Returns:
|
||||
array: The convolved array.
|
||||
@@ -3296,8 +3296,8 @@ void init_ops(nb::module_& m) {
|
||||
2D convolution over an input with several channels
|
||||
|
||||
Args:
|
||||
input (array): input array of shape ``(N, H, W, C_in)``
|
||||
weight (array): weight array of shape ``(C_out, H, W, C_in)``
|
||||
input (array): Input array of shape ``(N, H, W, C_in)``.
|
||||
weight (array): Weight array of shape ``(C_out, H, W, C_in)``.
|
||||
stride (int or tuple(int), optional): :obj:`tuple` of size 2 with
|
||||
kernel strides. All spatial dimensions get the same stride if
|
||||
only one number is specified. Default: ``1``.
|
||||
@@ -3368,8 +3368,173 @@ void init_ops(nb::module_& m) {
|
||||
Note: Only the default ``groups=1`` is currently supported.
|
||||
|
||||
Args:
|
||||
input (array): input array of shape ``(N, D, H, W, C_in)``
|
||||
weight (array): weight array of shape ``(C_out, D, H, W, C_in)``
|
||||
input (array): Input array of shape ``(N, D, H, W, C_in)``.
|
||||
weight (array): Weight array of shape ``(C_out, D, H, W, C_in)``.
|
||||
stride (int or tuple(int), optional): :obj:`tuple` of size 3 with
|
||||
kernel strides. All spatial dimensions get the same stride if
|
||||
only one number is specified. Default: ``1``.
|
||||
padding (int or tuple(int), optional): :obj:`tuple` of size 3 with
|
||||
symmetric input padding. All spatial dimensions get the same
|
||||
padding if only one number is specified. Default: ``0``.
|
||||
dilation (int or tuple(int), optional): :obj:`tuple` of size 3 with
|
||||
kernel dilation. All spatial dimensions get the same dilation
|
||||
if only one number is specified. Default: ``1``.
|
||||
groups (int, optional): Input feature groups. Default: ``1``.
|
||||
|
||||
Returns:
|
||||
array: The convolved array.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"conv_transpose1d",
|
||||
&conv_transpose1d,
|
||||
nb::arg(),
|
||||
nb::arg(),
|
||||
"stride"_a = 1,
|
||||
"padding"_a = 0,
|
||||
"dilation"_a = 1,
|
||||
"groups"_a = 1,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def conv_transpose1d(input: array, weight: array, /, stride: int = 1, padding: int = 0, dilation: int = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
1D transposed convolution over an input with several channels
|
||||
|
||||
Args:
|
||||
input (array): Input array of shape ``(N, H, C_in)``.
|
||||
weight (array): Weight array of shape ``(C_out, H, C_in)``.
|
||||
stride (int, optional): Kernel stride. Default: ``1``.
|
||||
padding (int, optional): Input padding. Default: ``0``.
|
||||
dilation (int, optional): Kernel dilation. Default: ``1``.
|
||||
groups (int, optional): Input feature groups. Default: ``1``.
|
||||
|
||||
Returns:
|
||||
array: The convolved array.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"conv_transpose2d",
|
||||
[](const array& input,
|
||||
const array& weight,
|
||||
const std::variant<int, std::pair<int, int>>& stride,
|
||||
const std::variant<int, std::pair<int, int>>& padding,
|
||||
const std::variant<int, std::pair<int, int>>& dilation,
|
||||
int groups,
|
||||
StreamOrDevice s) {
|
||||
std::pair<int, int> stride_pair{1, 1};
|
||||
std::pair<int, int> padding_pair{0, 0};
|
||||
std::pair<int, int> dilation_pair{1, 1};
|
||||
|
||||
if (auto pv = std::get_if<int>(&stride); pv) {
|
||||
stride_pair = std::pair<int, int>{*pv, *pv};
|
||||
} else {
|
||||
stride_pair = std::get<std::pair<int, int>>(stride);
|
||||
}
|
||||
|
||||
if (auto pv = std::get_if<int>(&padding); pv) {
|
||||
padding_pair = std::pair<int, int>{*pv, *pv};
|
||||
} else {
|
||||
padding_pair = std::get<std::pair<int, int>>(padding);
|
||||
}
|
||||
|
||||
if (auto pv = std::get_if<int>(&dilation); pv) {
|
||||
dilation_pair = std::pair<int, int>{*pv, *pv};
|
||||
} else {
|
||||
dilation_pair = std::get<std::pair<int, int>>(dilation);
|
||||
}
|
||||
|
||||
return conv_transpose2d(
|
||||
input, weight, stride_pair, padding_pair, dilation_pair, groups, s);
|
||||
},
|
||||
nb::arg(),
|
||||
nb::arg(),
|
||||
"stride"_a = 1,
|
||||
"padding"_a = 0,
|
||||
"dilation"_a = 1,
|
||||
"groups"_a = 1,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def conv_transpose2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
2D transposed convolution over an input with several channels
|
||||
|
||||
Note: Only the default ``groups=1`` is currently supported.
|
||||
|
||||
Args:
|
||||
input (array): Input array of shape ``(N, H, W, C_in)``.
|
||||
weight (array): Weight array of shape ``(C_out, H, W, C_in)``.
|
||||
stride (int or tuple(int), optional): :obj:`tuple` of size 2 with
|
||||
kernel strides. All spatial dimensions get the same stride if
|
||||
only one number is specified. Default: ``1``.
|
||||
padding (int or tuple(int), optional): :obj:`tuple` of size 2 with
|
||||
symmetric input padding. All spatial dimensions get the same
|
||||
padding if only one number is specified. Default: ``0``.
|
||||
dilation (int or tuple(int), optional): :obj:`tuple` of size 2 with
|
||||
kernel dilation. All spatial dimensions get the same dilation
|
||||
if only one number is specified. Default: ``1``.
|
||||
groups (int, optional): Input feature groups. Default: ``1``.
|
||||
|
||||
Returns:
|
||||
array: The convolved array.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"conv_transpose3d",
|
||||
[](const array& input,
|
||||
const array& weight,
|
||||
const std::variant<int, std::tuple<int, int, int>>& stride,
|
||||
const std::variant<int, std::tuple<int, int, int>>& padding,
|
||||
const std::variant<int, std::tuple<int, int, int>>& dilation,
|
||||
int groups,
|
||||
StreamOrDevice s) {
|
||||
std::tuple<int, int, int> stride_tuple{1, 1, 1};
|
||||
std::tuple<int, int, int> padding_tuple{0, 0, 0};
|
||||
std::tuple<int, int, int> dilation_tuple{1, 1, 1};
|
||||
|
||||
if (auto pv = std::get_if<int>(&stride); pv) {
|
||||
stride_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};
|
||||
} else {
|
||||
stride_tuple = std::get<std::tuple<int, int, int>>(stride);
|
||||
}
|
||||
|
||||
if (auto pv = std::get_if<int>(&padding); pv) {
|
||||
padding_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};
|
||||
} else {
|
||||
padding_tuple = std::get<std::tuple<int, int, int>>(padding);
|
||||
}
|
||||
|
||||
if (auto pv = std::get_if<int>(&dilation); pv) {
|
||||
dilation_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};
|
||||
} else {
|
||||
dilation_tuple = std::get<std::tuple<int, int, int>>(dilation);
|
||||
}
|
||||
|
||||
return conv_transpose3d(
|
||||
input,
|
||||
weight,
|
||||
stride_tuple,
|
||||
padding_tuple,
|
||||
dilation_tuple,
|
||||
groups,
|
||||
s);
|
||||
},
|
||||
nb::arg(),
|
||||
nb::arg(),
|
||||
"stride"_a = 1,
|
||||
"padding"_a = 0,
|
||||
"dilation"_a = 1,
|
||||
"groups"_a = 1,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def conv_transpose3d(input: array, weight: array, /, stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, dilation: Union[int, Tuple[int, int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
3D transposed convolution over an input with several channels
|
||||
|
||||
Note: Only the default ``groups=1`` is currently supported.
|
||||
|
||||
Args:
|
||||
input (array): Input array of shape ``(N, D, H, W, C_in)``.
|
||||
weight (array): Weight array of shape ``(C_out, D, H, W, C_in)``.
|
||||
stride (int or tuple(int), optional): :obj:`tuple` of size 3 with
|
||||
kernel strides. All spatial dimensions get the same stride if
|
||||
only one number is specified. Default: ``1``.
|
||||
@@ -3465,8 +3630,8 @@ void init_ops(nb::module_& m) {
|
||||
General convolution over an input with several channels
|
||||
|
||||
Args:
|
||||
input (array): Input array of shape ``(N, ..., C_in)``
|
||||
weight (array): Weight array of shape ``(C_out, ..., C_in)``
|
||||
input (array): Input array of shape ``(N, ..., C_in)``.
|
||||
weight (array): Weight array of shape ``(C_out, ..., C_in)``.
|
||||
stride (int or list(int), optional): :obj:`list` with kernel strides.
|
||||
All spatial dimensions get the same stride if
|
||||
only one number is specified. Default: ``1``.
|
||||
|
||||
Reference in New Issue
Block a user