Fix conv_general and docstrings

This commit is contained in:
Angelos Katharopoulos 2025-04-22 15:41:41 -07:00
parent f2b5ba49af
commit 7a41a7051b
2 changed files with 10 additions and 24 deletions

View File

@ -25,8 +25,8 @@ class ConvTranspose1d(Module):
padding (int, optional): How many positions to 0-pad the input with.
Default: ``0``.
dilation (int, optional): The dilation of the convolution.
output_padding(int, optional): Additional size added to one side of the output shape.
Default: ``0``.
output_padding(int, optional): Additional size added to one side of the
output shape. Default: ``0``.
bias (bool, optional): If ``True`` add a learnable bias to the output.
Default: ``True``
"""
@ -100,8 +100,8 @@ class ConvTranspose2d(Module):
padding (int or tuple, optional): How many positions to 0-pad
the input with. Default: ``0``.
dilation (int or tuple, optional): The dilation of the convolution.
output_padding(int or tuple, optional): Additional size added to one side of the output shape.
Default: ``0``.
output_padding(int or tuple, optional): Additional size added to one
side of the output shape. Default: ``0``.
bias (bool, optional): If ``True`` add a learnable bias to the
output. Default: ``True``
"""
@ -180,8 +180,8 @@ class ConvTranspose3d(Module):
padding (int or tuple, optional): How many positions to 0-pad
the input with. Default: ``0``.
dilation (int or tuple, optional): The dilation of the convolution.
output_padding(int or tuple, optional): Additional size added to one side of the output shape.
Default: ``0``.
output_padding(int or tuple, optional): Additional size added to one
side of the output shape. Default: ``0``.
bias (bool, optional): If ``True`` add a learnable bias to the
output. Default: ``True``
"""

View File

@ -3614,7 +3614,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def conv_transpose1d(input: array, weight: array, /, stride: int = 1, padding: int = 0, dilation: int = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
"def conv_transpose1d(input: array, weight: array, /, stride: int = 1, padding: int = 0, dilation: int = 1, output_padding: int = 0, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
1D transposed convolution over an input with several channels
@ -3689,7 +3689,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def conv_transpose2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
"def conv_transpose2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, output_padding: Union[int, Tuple[int, int]] = 0, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
2D transposed convolution over an input with several channels
@ -3775,7 +3775,7 @@ void init_ops(nb::module_& m) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def conv_transpose3d(input: array, weight: array, /, stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, dilation: Union[int, Tuple[int, int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
"def conv_transpose3d(input: array, weight: array, /, stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, dilation: Union[int, Tuple[int, int, int]] = 1, output_padding: Union[int, Tuple[int, int, int]] = 0, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
3D transposed convolution over an input with several channels
@ -3794,7 +3794,7 @@ void init_ops(nb::module_& m) {
kernel dilation. All spatial dimensions get the same dilation
if only one number is specified. Default: ``1``
output_padding (int or tuple(int), optional): :obj:`tuple` of size 3 with
output padding. All spatial dimensions get the same output
output padding. All spatial dimensions get the same output
padding if only one number is specified. Default: ``0``.
groups (int, optional): input feature groups. Default: ``1``.
@ -3812,7 +3812,6 @@ void init_ops(nb::module_& m) {
std::pair<std::vector<int>, std::vector<int>>>& padding,
const std::variant<int, std::vector<int>>& kernel_dilation,
const std::variant<int, std::vector<int>>& input_dilation,
const std::variant<int, std::vector<int>>& output_padding,
int groups,
bool flip,
mx::StreamOrDevice s) {
@ -3821,7 +3820,6 @@ void init_ops(nb::module_& m) {
std::vector<int> padding_hi_vec;
std::vector<int> kernel_dilation_vec;
std::vector<int> input_dilation_vec;
std::vector<int> output_padding_vec;
if (auto pv = std::get_if<int>(&stride); pv) {
stride_vec.push_back(*pv);
@ -3854,12 +3852,6 @@ void init_ops(nb::module_& m) {
input_dilation_vec = std::get<std::vector<int>>(input_dilation);
}
if (auto pv = std::get_if<int>(&output_padding); pv) {
output_padding_vec.push_back(*pv);
} else {
output_padding_vec = std::get<std::vector<int>>(output_padding);
}
return mx::conv_general(
/* array input = */ std::move(input),
/* array weight = */ std::move(weight),
@ -3870,8 +3862,6 @@ void init_ops(nb::module_& m) {
std::move(kernel_dilation_vec),
/* std::vector<int> input_dilation = */
std::move(input_dilation_vec),
/* std::vector<int> output_padding = */
std::move(output_padding_vec),
/* int groups = */ groups,
/* bool flip = */ flip,
s);
@ -3882,7 +3872,6 @@ void init_ops(nb::module_& m) {
"padding"_a = 0,
"kernel_dilation"_a = 1,
"input_dilation"_a = 1,
"output_padding"_a = 0,
"groups"_a = 1,
"flip"_a = false,
nb::kw_only(),
@ -3907,9 +3896,6 @@ void init_ops(nb::module_& m) {
input_dilation (int or list(int), optional): :obj:`list` with
input dilation. All spatial dimensions get the same dilation
if only one number is specified. Default: ``1``
output_padding (int or list(int), optional): :obj:`list` with
output padding. All spatial dimensions get the same padding
if only one number is specified. Default: ``0``
groups (int, optional): Input feature groups. Default: ``1``.
flip (bool, optional): Flip the order in which the spatial dimensions of
the weights are processed. Performs the cross-correlation operator when