mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Fix conv_general and docstrings
This commit is contained in:
parent
f2b5ba49af
commit
7a41a7051b
@ -25,8 +25,8 @@ class ConvTranspose1d(Module):
|
||||
padding (int, optional): How many positions to 0-pad the input with.
|
||||
Default: ``0``.
|
||||
dilation (int, optional): The dilation of the convolution.
|
||||
output_padding(int, optional): Additional size added to one side of the output shape.
|
||||
Default: ``0``.
|
||||
output_padding(int, optional): Additional size added to one side of the
|
||||
output shape. Default: ``0``.
|
||||
bias (bool, optional): If ``True`` add a learnable bias to the output.
|
||||
Default: ``True``
|
||||
"""
|
||||
@ -100,8 +100,8 @@ class ConvTranspose2d(Module):
|
||||
padding (int or tuple, optional): How many positions to 0-pad
|
||||
the input with. Default: ``0``.
|
||||
dilation (int or tuple, optional): The dilation of the convolution.
|
||||
output_padding(int or tuple, optional): Additional size added to one side of the output shape.
|
||||
Default: ``0``.
|
||||
output_padding(int or tuple, optional): Additional size added to one
|
||||
side of the output shape. Default: ``0``.
|
||||
bias (bool, optional): If ``True`` add a learnable bias to the
|
||||
output. Default: ``True``
|
||||
"""
|
||||
@ -180,8 +180,8 @@ class ConvTranspose3d(Module):
|
||||
padding (int or tuple, optional): How many positions to 0-pad
|
||||
the input with. Default: ``0``.
|
||||
dilation (int or tuple, optional): The dilation of the convolution.
|
||||
output_padding(int or tuple, optional): Additional size added to one side of the output shape.
|
||||
Default: ``0``.
|
||||
output_padding(int or tuple, optional): Additional size added to one
|
||||
side of the output shape. Default: ``0``.
|
||||
bias (bool, optional): If ``True`` add a learnable bias to the
|
||||
output. Default: ``True``
|
||||
"""
|
||||
|
@ -3614,7 +3614,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def conv_transpose1d(input: array, weight: array, /, stride: int = 1, padding: int = 0, dilation: int = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def conv_transpose1d(input: array, weight: array, /, stride: int = 1, padding: int = 0, dilation: int = 1, output_padding: int = 0, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
1D transposed convolution over an input with several channels
|
||||
|
||||
@ -3689,7 +3689,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def conv_transpose2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def conv_transpose2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, output_padding: Union[int, Tuple[int, int]] = 0, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
2D transposed convolution over an input with several channels
|
||||
|
||||
@ -3775,7 +3775,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def conv_transpose3d(input: array, weight: array, /, stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, dilation: Union[int, Tuple[int, int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def conv_transpose3d(input: array, weight: array, /, stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, dilation: Union[int, Tuple[int, int, int]] = 1, output_padding: Union[int, Tuple[int, int, int]] = 0, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
3D transposed convolution over an input with several channels
|
||||
|
||||
@ -3812,7 +3812,6 @@ void init_ops(nb::module_& m) {
|
||||
std::pair<std::vector<int>, std::vector<int>>>& padding,
|
||||
const std::variant<int, std::vector<int>>& kernel_dilation,
|
||||
const std::variant<int, std::vector<int>>& input_dilation,
|
||||
const std::variant<int, std::vector<int>>& output_padding,
|
||||
int groups,
|
||||
bool flip,
|
||||
mx::StreamOrDevice s) {
|
||||
@ -3821,7 +3820,6 @@ void init_ops(nb::module_& m) {
|
||||
std::vector<int> padding_hi_vec;
|
||||
std::vector<int> kernel_dilation_vec;
|
||||
std::vector<int> input_dilation_vec;
|
||||
std::vector<int> output_padding_vec;
|
||||
|
||||
if (auto pv = std::get_if<int>(&stride); pv) {
|
||||
stride_vec.push_back(*pv);
|
||||
@ -3854,12 +3852,6 @@ void init_ops(nb::module_& m) {
|
||||
input_dilation_vec = std::get<std::vector<int>>(input_dilation);
|
||||
}
|
||||
|
||||
if (auto pv = std::get_if<int>(&output_padding); pv) {
|
||||
output_padding_vec.push_back(*pv);
|
||||
} else {
|
||||
output_padding_vec = std::get<std::vector<int>>(output_padding);
|
||||
}
|
||||
|
||||
return mx::conv_general(
|
||||
/* array input = */ std::move(input),
|
||||
/* array weight = */ std::move(weight),
|
||||
@ -3870,8 +3862,6 @@ void init_ops(nb::module_& m) {
|
||||
std::move(kernel_dilation_vec),
|
||||
/* std::vector<int> input_dilation = */
|
||||
std::move(input_dilation_vec),
|
||||
/* std::vector<int> output_padding = */
|
||||
std::move(output_padding_vec),
|
||||
/* int groups = */ groups,
|
||||
/* bool flip = */ flip,
|
||||
s);
|
||||
@ -3882,7 +3872,6 @@ void init_ops(nb::module_& m) {
|
||||
"padding"_a = 0,
|
||||
"kernel_dilation"_a = 1,
|
||||
"input_dilation"_a = 1,
|
||||
"output_padding"_a = 0,
|
||||
"groups"_a = 1,
|
||||
"flip"_a = false,
|
||||
nb::kw_only(),
|
||||
@ -3907,9 +3896,6 @@ void init_ops(nb::module_& m) {
|
||||
input_dilation (int or list(int), optional): :obj:`list` with
|
||||
input dilation. All spatial dimensions get the same dilation
|
||||
if only one number is specified. Default: ``1``
|
||||
output_padding (int or list(int), optional): :obj:`list` with
|
||||
output padding. All spatial dimensions get the same padding
|
||||
if only one number is specified. Default: ``0``
|
||||
groups (int, optional): Input feature groups. Default: ``1``.
|
||||
flip (bool, optional): Flip the order in which the spatial dimensions of
|
||||
the weights are processed. Performs the cross-correlation operator when
|
||||
|
Loading…
Reference in New Issue
Block a user