mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Fix conv_general and docstrings
This commit is contained in:
parent
f2b5ba49af
commit
7a41a7051b
@ -25,8 +25,8 @@ class ConvTranspose1d(Module):
|
|||||||
padding (int, optional): How many positions to 0-pad the input with.
|
padding (int, optional): How many positions to 0-pad the input with.
|
||||||
Default: ``0``.
|
Default: ``0``.
|
||||||
dilation (int, optional): The dilation of the convolution.
|
dilation (int, optional): The dilation of the convolution.
|
||||||
output_padding(int, optional): Additional size added to one side of the output shape.
|
output_padding(int, optional): Additional size added to one side of the
|
||||||
Default: ``0``.
|
output shape. Default: ``0``.
|
||||||
bias (bool, optional): If ``True`` add a learnable bias to the output.
|
bias (bool, optional): If ``True`` add a learnable bias to the output.
|
||||||
Default: ``True``
|
Default: ``True``
|
||||||
"""
|
"""
|
||||||
@ -100,8 +100,8 @@ class ConvTranspose2d(Module):
|
|||||||
padding (int or tuple, optional): How many positions to 0-pad
|
padding (int or tuple, optional): How many positions to 0-pad
|
||||||
the input with. Default: ``0``.
|
the input with. Default: ``0``.
|
||||||
dilation (int or tuple, optional): The dilation of the convolution.
|
dilation (int or tuple, optional): The dilation of the convolution.
|
||||||
output_padding(int or tuple, optional): Additional size added to one side of the output shape.
|
output_padding(int or tuple, optional): Additional size added to one
|
||||||
Default: ``0``.
|
side of the output shape. Default: ``0``.
|
||||||
bias (bool, optional): If ``True`` add a learnable bias to the
|
bias (bool, optional): If ``True`` add a learnable bias to the
|
||||||
output. Default: ``True``
|
output. Default: ``True``
|
||||||
"""
|
"""
|
||||||
@ -180,8 +180,8 @@ class ConvTranspose3d(Module):
|
|||||||
padding (int or tuple, optional): How many positions to 0-pad
|
padding (int or tuple, optional): How many positions to 0-pad
|
||||||
the input with. Default: ``0``.
|
the input with. Default: ``0``.
|
||||||
dilation (int or tuple, optional): The dilation of the convolution.
|
dilation (int or tuple, optional): The dilation of the convolution.
|
||||||
output_padding(int or tuple, optional): Additional size added to one side of the output shape.
|
output_padding(int or tuple, optional): Additional size added to one
|
||||||
Default: ``0``.
|
side of the output shape. Default: ``0``.
|
||||||
bias (bool, optional): If ``True`` add a learnable bias to the
|
bias (bool, optional): If ``True`` add a learnable bias to the
|
||||||
output. Default: ``True``
|
output. Default: ``True``
|
||||||
"""
|
"""
|
||||||
|
@ -3614,7 +3614,7 @@ void init_ops(nb::module_& m) {
|
|||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
nb::sig(
|
nb::sig(
|
||||||
"def conv_transpose1d(input: array, weight: array, /, stride: int = 1, padding: int = 0, dilation: int = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
|
"def conv_transpose1d(input: array, weight: array, /, stride: int = 1, padding: int = 0, dilation: int = 1, output_padding: int = 0, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
1D transposed convolution over an input with several channels
|
1D transposed convolution over an input with several channels
|
||||||
|
|
||||||
@ -3689,7 +3689,7 @@ void init_ops(nb::module_& m) {
|
|||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
nb::sig(
|
nb::sig(
|
||||||
"def conv_transpose2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
|
"def conv_transpose2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, output_padding: Union[int, Tuple[int, int]] = 0, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
2D transposed convolution over an input with several channels
|
2D transposed convolution over an input with several channels
|
||||||
|
|
||||||
@ -3775,7 +3775,7 @@ void init_ops(nb::module_& m) {
|
|||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
nb::sig(
|
nb::sig(
|
||||||
"def conv_transpose3d(input: array, weight: array, /, stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, dilation: Union[int, Tuple[int, int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
|
"def conv_transpose3d(input: array, weight: array, /, stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, dilation: Union[int, Tuple[int, int, int]] = 1, output_padding: Union[int, Tuple[int, int, int]] = 0, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
3D transposed convolution over an input with several channels
|
3D transposed convolution over an input with several channels
|
||||||
|
|
||||||
@ -3794,7 +3794,7 @@ void init_ops(nb::module_& m) {
|
|||||||
kernel dilation. All spatial dimensions get the same dilation
|
kernel dilation. All spatial dimensions get the same dilation
|
||||||
if only one number is specified. Default: ``1``
|
if only one number is specified. Default: ``1``
|
||||||
output_padding (int or tuple(int), optional): :obj:`tuple` of size 3 with
|
output_padding (int or tuple(int), optional): :obj:`tuple` of size 3 with
|
||||||
output padding. All spatial dimensions get the same output
|
output padding. All spatial dimensions get the same output
|
||||||
padding if only one number is specified. Default: ``0``.
|
padding if only one number is specified. Default: ``0``.
|
||||||
groups (int, optional): input feature groups. Default: ``1``.
|
groups (int, optional): input feature groups. Default: ``1``.
|
||||||
|
|
||||||
@ -3812,7 +3812,6 @@ void init_ops(nb::module_& m) {
|
|||||||
std::pair<std::vector<int>, std::vector<int>>>& padding,
|
std::pair<std::vector<int>, std::vector<int>>>& padding,
|
||||||
const std::variant<int, std::vector<int>>& kernel_dilation,
|
const std::variant<int, std::vector<int>>& kernel_dilation,
|
||||||
const std::variant<int, std::vector<int>>& input_dilation,
|
const std::variant<int, std::vector<int>>& input_dilation,
|
||||||
const std::variant<int, std::vector<int>>& output_padding,
|
|
||||||
int groups,
|
int groups,
|
||||||
bool flip,
|
bool flip,
|
||||||
mx::StreamOrDevice s) {
|
mx::StreamOrDevice s) {
|
||||||
@ -3821,7 +3820,6 @@ void init_ops(nb::module_& m) {
|
|||||||
std::vector<int> padding_hi_vec;
|
std::vector<int> padding_hi_vec;
|
||||||
std::vector<int> kernel_dilation_vec;
|
std::vector<int> kernel_dilation_vec;
|
||||||
std::vector<int> input_dilation_vec;
|
std::vector<int> input_dilation_vec;
|
||||||
std::vector<int> output_padding_vec;
|
|
||||||
|
|
||||||
if (auto pv = std::get_if<int>(&stride); pv) {
|
if (auto pv = std::get_if<int>(&stride); pv) {
|
||||||
stride_vec.push_back(*pv);
|
stride_vec.push_back(*pv);
|
||||||
@ -3854,12 +3852,6 @@ void init_ops(nb::module_& m) {
|
|||||||
input_dilation_vec = std::get<std::vector<int>>(input_dilation);
|
input_dilation_vec = std::get<std::vector<int>>(input_dilation);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto pv = std::get_if<int>(&output_padding); pv) {
|
|
||||||
output_padding_vec.push_back(*pv);
|
|
||||||
} else {
|
|
||||||
output_padding_vec = std::get<std::vector<int>>(output_padding);
|
|
||||||
}
|
|
||||||
|
|
||||||
return mx::conv_general(
|
return mx::conv_general(
|
||||||
/* array input = */ std::move(input),
|
/* array input = */ std::move(input),
|
||||||
/* array weight = */ std::move(weight),
|
/* array weight = */ std::move(weight),
|
||||||
@ -3870,8 +3862,6 @@ void init_ops(nb::module_& m) {
|
|||||||
std::move(kernel_dilation_vec),
|
std::move(kernel_dilation_vec),
|
||||||
/* std::vector<int> input_dilation = */
|
/* std::vector<int> input_dilation = */
|
||||||
std::move(input_dilation_vec),
|
std::move(input_dilation_vec),
|
||||||
/* std::vector<int> output_padding = */
|
|
||||||
std::move(output_padding_vec),
|
|
||||||
/* int groups = */ groups,
|
/* int groups = */ groups,
|
||||||
/* bool flip = */ flip,
|
/* bool flip = */ flip,
|
||||||
s);
|
s);
|
||||||
@ -3882,7 +3872,6 @@ void init_ops(nb::module_& m) {
|
|||||||
"padding"_a = 0,
|
"padding"_a = 0,
|
||||||
"kernel_dilation"_a = 1,
|
"kernel_dilation"_a = 1,
|
||||||
"input_dilation"_a = 1,
|
"input_dilation"_a = 1,
|
||||||
"output_padding"_a = 0,
|
|
||||||
"groups"_a = 1,
|
"groups"_a = 1,
|
||||||
"flip"_a = false,
|
"flip"_a = false,
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
@ -3907,9 +3896,6 @@ void init_ops(nb::module_& m) {
|
|||||||
input_dilation (int or list(int), optional): :obj:`list` with
|
input_dilation (int or list(int), optional): :obj:`list` with
|
||||||
input dilation. All spatial dimensions get the same dilation
|
input dilation. All spatial dimensions get the same dilation
|
||||||
if only one number is specified. Default: ``1``
|
if only one number is specified. Default: ``1``
|
||||||
output_padding (int or list(int), optional): :obj:`list` with
|
|
||||||
output padding. All spatial dimensions get the same padding
|
|
||||||
if only one number is specified. Default: ``0``
|
|
||||||
groups (int, optional): Input feature groups. Default: ``1``.
|
groups (int, optional): Input feature groups. Default: ``1``.
|
||||||
flip (bool, optional): Flip the order in which the spatial dimensions of
|
flip (bool, optional): Flip the order in which the spatial dimensions of
|
||||||
the weights are processed. Performs the cross-correlation operator when
|
the weights are processed. Performs the cross-correlation operator when
|
||||||
|
Loading…
Reference in New Issue
Block a user