mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	Transposed Convolution (#1245)
* initial implementation for conv_transpose ran pre-commit implemented conv_transpose updated conv_general docstring updated conv_general docstring updated code comments removed commented run_conv_checks updated acknowledgments added missing entry to ops.rst added op to nn.layers resolved merge conflicts * removed ConvolutionTranspose primitive as suggested by reviewer removed ConvolutionTranspose primitive as suggested by reviewer * remove transpose flag, add another test --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
		
				
					committed by
					
						
						GitHub
					
				
			
			
				
	
			
			
			
						parent
						
							ba3e913c7a
						
					
				
				
					commit
					efeb9c0f02
				
			@@ -3238,12 +3238,12 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
        1D convolution over an input with several channels
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            input (array): input array of shape (``N``, ``H``, ``C_in``)
 | 
			
		||||
            weight (array): weight array of shape (``C_out``, ``H``, ``C_in``)
 | 
			
		||||
            stride (int, optional): kernel stride. Default: ``1``.
 | 
			
		||||
            padding (int, optional): input padding. Default: ``0``.
 | 
			
		||||
            dilation (int, optional): kernel dilation. Default: ``1``.
 | 
			
		||||
            groups (int, optional): input feature groups. Default: ``1``.
 | 
			
		||||
            input (array): Input array of shape ``(N, H, C_in)``.
 | 
			
		||||
            weight (array): Weight array of shape ``(C_out, H, C_in)``.
 | 
			
		||||
            stride (int, optional): Kernel stride. Default: ``1``.
 | 
			
		||||
            padding (int, optional): Input padding. Default: ``0``.
 | 
			
		||||
            dilation (int, optional): Kernel dilation. Default: ``1``.
 | 
			
		||||
            groups (int, optional): Input feature groups. Default: ``1``.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            array: The convolved array.
 | 
			
		||||
@@ -3296,8 +3296,8 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
        2D convolution over an input with several channels
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            input (array): input array of shape ``(N, H, W, C_in)``
 | 
			
		||||
            weight (array): weight array of shape ``(C_out, H, W, C_in)``
 | 
			
		||||
            input (array): Input array of shape ``(N, H, W, C_in)``.
 | 
			
		||||
            weight (array): Weight array of shape ``(C_out, H, W, C_in)``.
 | 
			
		||||
            stride (int or tuple(int), optional): :obj:`tuple` of size 2 with
 | 
			
		||||
                kernel strides. All spatial dimensions get the same stride if
 | 
			
		||||
                only one number is specified. Default: ``1``.
 | 
			
		||||
@@ -3368,8 +3368,173 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
        Note: Only the default ``groups=1`` is currently supported.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            input (array): input array of shape ``(N, D, H, W, C_in)``
 | 
			
		||||
            weight (array): weight array of shape ``(C_out, D, H, W, C_in)``
 | 
			
		||||
            input (array): Input array of shape ``(N, D, H, W, C_in)``.
 | 
			
		||||
            weight (array): Weight array of shape ``(C_out, D, H, W, C_in)``.
 | 
			
		||||
            stride (int or tuple(int), optional): :obj:`tuple` of size 3 with
 | 
			
		||||
                kernel strides. All spatial dimensions get the same stride if
 | 
			
		||||
                only one number is specified. Default: ``1``.
 | 
			
		||||
            padding (int or tuple(int), optional): :obj:`tuple` of size 3 with
 | 
			
		||||
                symmetric input padding. All spatial dimensions get the same
 | 
			
		||||
                padding if only one number is specified. Default: ``0``.
 | 
			
		||||
            dilation (int or tuple(int), optional): :obj:`tuple` of size 3 with
 | 
			
		||||
                kernel dilation. All spatial dimensions get the same dilation
 | 
			
		||||
                if only one number is specified. Default: ``1``
 | 
			
		||||
            groups (int, optional): input feature groups. Default: ``1``.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            array: The convolved array.
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "conv_transpose1d",
 | 
			
		||||
      &conv_transpose1d,
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      "stride"_a = 1,
 | 
			
		||||
      "padding"_a = 0,
 | 
			
		||||
      "dilation"_a = 1,
 | 
			
		||||
      "groups"_a = 1,
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def conv_transpose1d(input: array, weight: array, /, stride: int = 1, padding: int = 0, dilation: int = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        1D transposed convolution over an input with several channels
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            input (array): Input array of shape ``(N, H, C_in)``.
 | 
			
		||||
            weight (array): Weight array of shape ``(C_out, H, C_in)``.
 | 
			
		||||
            stride (int, optional): Kernel stride. Default: ``1``.
 | 
			
		||||
            padding (int, optional): Input padding. Default: ``0``.
 | 
			
		||||
            dilation (int, optional): Kernel dilation. Default: ``1``.
 | 
			
		||||
            groups (int, optional): Input feature groups. Default: ``1``.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            array: The convolved array.
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "conv_transpose2d",
 | 
			
		||||
      [](const array& input,
 | 
			
		||||
         const array& weight,
 | 
			
		||||
         const std::variant<int, std::pair<int, int>>& stride,
 | 
			
		||||
         const std::variant<int, std::pair<int, int>>& padding,
 | 
			
		||||
         const std::variant<int, std::pair<int, int>>& dilation,
 | 
			
		||||
         int groups,
 | 
			
		||||
         StreamOrDevice s) {
 | 
			
		||||
        std::pair<int, int> stride_pair{1, 1};
 | 
			
		||||
        std::pair<int, int> padding_pair{0, 0};
 | 
			
		||||
        std::pair<int, int> dilation_pair{1, 1};
 | 
			
		||||
 | 
			
		||||
        if (auto pv = std::get_if<int>(&stride); pv) {
 | 
			
		||||
          stride_pair = std::pair<int, int>{*pv, *pv};
 | 
			
		||||
        } else {
 | 
			
		||||
          stride_pair = std::get<std::pair<int, int>>(stride);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if (auto pv = std::get_if<int>(&padding); pv) {
 | 
			
		||||
          padding_pair = std::pair<int, int>{*pv, *pv};
 | 
			
		||||
        } else {
 | 
			
		||||
          padding_pair = std::get<std::pair<int, int>>(padding);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if (auto pv = std::get_if<int>(&dilation); pv) {
 | 
			
		||||
          dilation_pair = std::pair<int, int>{*pv, *pv};
 | 
			
		||||
        } else {
 | 
			
		||||
          dilation_pair = std::get<std::pair<int, int>>(dilation);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        return conv_transpose2d(
 | 
			
		||||
            input, weight, stride_pair, padding_pair, dilation_pair, groups, s);
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      "stride"_a = 1,
 | 
			
		||||
      "padding"_a = 0,
 | 
			
		||||
      "dilation"_a = 1,
 | 
			
		||||
      "groups"_a = 1,
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def conv_transpose2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        2D transposed convolution over an input with several channels
 | 
			
		||||
 | 
			
		||||
        Note: Only the default ``groups=1`` is currently supported.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            input (array): Input array of shape ``(N, H, W, C_in)``.
 | 
			
		||||
            weight (array): Weight array of shape ``(C_out, H, W, C_in)``.
 | 
			
		||||
            stride (int or tuple(int), optional): :obj:`tuple` of size 2 with
 | 
			
		||||
                kernel strides. All spatial dimensions get the same stride if
 | 
			
		||||
                only one number is specified. Default: ``1``.
 | 
			
		||||
            padding (int or tuple(int), optional): :obj:`tuple` of size 2 with
 | 
			
		||||
                symmetric input padding. All spatial dimensions get the same
 | 
			
		||||
                padding if only one number is specified. Default: ``0``.
 | 
			
		||||
            dilation (int or tuple(int), optional): :obj:`tuple` of size 2 with
 | 
			
		||||
                kernel dilation. All spatial dimensions get the same dilation
 | 
			
		||||
                if only one number is specified. Default: ``1``
 | 
			
		||||
            groups (int, optional): input feature groups. Default: ``1``.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            array: The convolved array.
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "conv_transpose3d",
 | 
			
		||||
      [](const array& input,
 | 
			
		||||
         const array& weight,
 | 
			
		||||
         const std::variant<int, std::tuple<int, int, int>>& stride,
 | 
			
		||||
         const std::variant<int, std::tuple<int, int, int>>& padding,
 | 
			
		||||
         const std::variant<int, std::tuple<int, int, int>>& dilation,
 | 
			
		||||
         int groups,
 | 
			
		||||
         StreamOrDevice s) {
 | 
			
		||||
        std::tuple<int, int, int> stride_tuple{1, 1, 1};
 | 
			
		||||
        std::tuple<int, int, int> padding_tuple{0, 0, 0};
 | 
			
		||||
        std::tuple<int, int, int> dilation_tuple{1, 1, 1};
 | 
			
		||||
 | 
			
		||||
        if (auto pv = std::get_if<int>(&stride); pv) {
 | 
			
		||||
          stride_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};
 | 
			
		||||
        } else {
 | 
			
		||||
          stride_tuple = std::get<std::tuple<int, int, int>>(stride);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if (auto pv = std::get_if<int>(&padding); pv) {
 | 
			
		||||
          padding_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};
 | 
			
		||||
        } else {
 | 
			
		||||
          padding_tuple = std::get<std::tuple<int, int, int>>(padding);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if (auto pv = std::get_if<int>(&dilation); pv) {
 | 
			
		||||
          dilation_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};
 | 
			
		||||
        } else {
 | 
			
		||||
          dilation_tuple = std::get<std::tuple<int, int, int>>(dilation);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        return conv_transpose3d(
 | 
			
		||||
            input,
 | 
			
		||||
            weight,
 | 
			
		||||
            stride_tuple,
 | 
			
		||||
            padding_tuple,
 | 
			
		||||
            dilation_tuple,
 | 
			
		||||
            groups,
 | 
			
		||||
            s);
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      "stride"_a = 1,
 | 
			
		||||
      "padding"_a = 0,
 | 
			
		||||
      "dilation"_a = 1,
 | 
			
		||||
      "groups"_a = 1,
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def conv_transpose3d(input: array, weight: array, /, stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, dilation: Union[int, Tuple[int, int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        3D transposed convolution over an input with several channels
 | 
			
		||||
 | 
			
		||||
        Note: Only the default ``groups=1`` is currently supported.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            input (array): Input array of shape ``(N, D, H, W, C_in)``.
 | 
			
		||||
            weight (array): Weight array of shape ``(C_out, D, H, W, C_in)``.
 | 
			
		||||
            stride (int or tuple(int), optional): :obj:`tuple` of size 3 with
 | 
			
		||||
                kernel strides. All spatial dimensions get the same stride if
 | 
			
		||||
                only one number is specified. Default: ``1``.
 | 
			
		||||
@@ -3465,8 +3630,8 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
        General convolution over an input with several channels
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            input (array): Input array of shape ``(N, ..., C_in)``
 | 
			
		||||
            weight (array): Weight array of shape ``(C_out, ..., C_in)``
 | 
			
		||||
            input (array): Input array of shape ``(N, ..., C_in)``.
 | 
			
		||||
            weight (array): Weight array of shape ``(C_out, ..., C_in)``.
 | 
			
		||||
            stride (int or list(int), optional): :obj:`list` with kernel strides.
 | 
			
		||||
                All spatial dimensions get the same stride if
 | 
			
		||||
                only one number is specified. Default: ``1``.
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user