From 7a41a7051b4ebadc40abc8ff679e8f693dd9f18f Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 22 Apr 2025 15:41:41 -0700 Subject: [PATCH] Fix conv_general and docstrings --- python/mlx/nn/layers/convolution_transpose.py | 12 +++++----- python/src/ops.cpp | 22 ++++--------------- 2 files changed, 10 insertions(+), 24 deletions(-) diff --git a/python/mlx/nn/layers/convolution_transpose.py b/python/mlx/nn/layers/convolution_transpose.py index ff321ac98..a11c4cb40 100644 --- a/python/mlx/nn/layers/convolution_transpose.py +++ b/python/mlx/nn/layers/convolution_transpose.py @@ -25,8 +25,8 @@ class ConvTranspose1d(Module): padding (int, optional): How many positions to 0-pad the input with. Default: ``0``. dilation (int, optional): The dilation of the convolution. - output_padding(int, optional): Additional size added to one side of the output shape. - Default: ``0``. + output_padding(int, optional): Additional size added to one side of the + output shape. Default: ``0``. bias (bool, optional): If ``True`` add a learnable bias to the output. Default: ``True`` """ @@ -100,8 +100,8 @@ class ConvTranspose2d(Module): padding (int or tuple, optional): How many positions to 0-pad the input with. Default: ``0``. dilation (int or tuple, optional): The dilation of the convolution. - output_padding(int or tuple, optional): Additional size added to one side of the output shape. - Default: ``0``. + output_padding(int or tuple, optional): Additional size added to one + side of the output shape. Default: ``0``. bias (bool, optional): If ``True`` add a learnable bias to the output. Default: ``True`` """ @@ -180,8 +180,8 @@ class ConvTranspose3d(Module): padding (int or tuple, optional): How many positions to 0-pad the input with. Default: ``0``. dilation (int or tuple, optional): The dilation of the convolution. - output_padding(int or tuple, optional): Additional size added to one side of the output shape. - Default: ``0``. + output_padding(int or tuple, optional): Additional size added to one + side of the output shape. Default: ``0``. bias (bool, optional): If ``True`` add a learnable bias to the output. Default: ``True`` """ diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 8c5cea9f9..f6dd11da3 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3614,7 +3614,7 @@ void init_ops(nb::module_& m) { nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def conv_transpose1d(input: array, weight: array, /, stride: int = 1, padding: int = 0, dilation: int = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"), + "def conv_transpose1d(input: array, weight: array, /, stride: int = 1, padding: int = 0, dilation: int = 1, output_padding: int = 0, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( 1D transposed convolution over an input with several channels @@ -3689,7 +3689,7 @@ void init_ops(nb::module_& m) { nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def conv_transpose2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"), + "def conv_transpose2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, output_padding: Union[int, Tuple[int, int]] = 0, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( 2D transposed convolution over an input with several channels @@ -3775,7 +3775,7 @@ void init_ops(nb::module_& m) { nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def conv_transpose3d(input: array, weight: array, /, stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, dilation: Union[int, Tuple[int, int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"), + "def conv_transpose3d(input: array, weight: array, /, stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, dilation: Union[int, Tuple[int, int, int]] = 1, output_padding: Union[int, Tuple[int, int, int]] = 0, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( 3D transposed convolution over an input with several channels @@ -3794,7 +3794,7 @@ void init_ops(nb::module_& m) { kernel dilation. All spatial dimensions get the same dilation if only one number is specified. Default: ``1`` output_padding (int or tuple(int), optional): :obj:`tuple` of size 3 with - output padding. All spatial dimensions get the same output + output padding. All spatial dimensions get the same output padding if only one number is specified. Default: ``0``. groups (int, optional): input feature groups. Default: ``1``. @@ -3812,7 +3812,6 @@ void init_ops(nb::module_& m) { std::pair, std::vector>>& padding, const std::variant>& kernel_dilation, const std::variant>& input_dilation, - const std::variant>& output_padding, int groups, bool flip, mx::StreamOrDevice s) { @@ -3821,7 +3820,6 @@ void init_ops(nb::module_& m) { std::vector padding_hi_vec; std::vector kernel_dilation_vec; std::vector input_dilation_vec; - std::vector output_padding_vec; if (auto pv = std::get_if(&stride); pv) { stride_vec.push_back(*pv); @@ -3854,12 +3852,6 @@ void init_ops(nb::module_& m) { input_dilation_vec = std::get>(input_dilation); } - if (auto pv = std::get_if(&output_padding); pv) { - output_padding_vec.push_back(*pv); - } else { - output_padding_vec = std::get>(output_padding); - } - return mx::conv_general( /* array input = */ std::move(input), /* array weight = */ std::move(weight), @@ -3870,8 +3862,6 @@ void init_ops(nb::module_& m) { std::move(kernel_dilation_vec), /* std::vector input_dilation = */ std::move(input_dilation_vec), - /* std::vector output_padding = */ - std::move(output_padding_vec), /* int groups = */ groups, /* bool flip = */ flip, s); @@ -3882,7 +3872,6 @@ void init_ops(nb::module_& m) { "padding"_a = 0, "kernel_dilation"_a = 1, "input_dilation"_a = 1, - "output_padding"_a = 0, "groups"_a = 1, "flip"_a = false, nb::kw_only(), @@ -3907,9 +3896,6 @@ void init_ops(nb::module_& m) { input_dilation (int or list(int), optional): :obj:`list` with input dilation. All spatial dimensions get the same dilation if only one number is specified. Default: ``1`` - output_padding (int or list(int), optional): :obj:`list` with - output padding. All spatial dimensions get the same padding - if only one number is specified. Default: ``0`` groups (int, optional): Input feature groups. Default: ``1``. flip (bool, optional): Flip the order in which the spatial dimensions of the weights are processed. Performs the cross-correlation operator when