mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	Conv3d (#993)
* added conv3d added conv3d implemented explicit_gemm_conv_ND_cpu and bounds checks for slow_conv_3D * incorporated reviewer comments * fixed test * reduced tensor shapes in test for conv3d * Reviewer suggestion Co-authored-by: Awni Hannun <awni.hannun@gmail.com> Reviewer suggestion Co-authored-by: Awni Hannun <awni.hannun@gmail.com> Reviewer suggestion Co-authored-by: Awni Hannun <awni.hannun@gmail.com> Reviewer suggestion
This commit is contained in:
		
				
					committed by
					
						
						GitHub
					
				
			
			
				
	
			
			
			
						parent
						
							a9f80d60f6
						
					
				
				
					commit
					ff4223904d
				
			@@ -3230,6 +3230,78 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
            array: The convolved array.
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "conv3d",
 | 
			
		||||
      [](const array& input,
 | 
			
		||||
         const array& weight,
 | 
			
		||||
         const std::variant<int, std::tuple<int, int, int>>& stride,
 | 
			
		||||
         const std::variant<int, std::tuple<int, int, int>>& padding,
 | 
			
		||||
         const std::variant<int, std::tuple<int, int, int>>& dilation,
 | 
			
		||||
         int groups,
 | 
			
		||||
         StreamOrDevice s) {
 | 
			
		||||
        std::tuple<int, int, int> stride_tuple{1, 1, 1};
 | 
			
		||||
        std::tuple<int, int, int> padding_tuple{0, 0, 0};
 | 
			
		||||
        std::tuple<int, int, int> dilation_tuple{1, 1, 1};
 | 
			
		||||
 | 
			
		||||
        if (auto pv = std::get_if<int>(&stride); pv) {
 | 
			
		||||
          stride_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};
 | 
			
		||||
        } else {
 | 
			
		||||
          stride_tuple = std::get<std::tuple<int, int, int>>(stride);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if (auto pv = std::get_if<int>(&padding); pv) {
 | 
			
		||||
          padding_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};
 | 
			
		||||
        } else {
 | 
			
		||||
          padding_tuple = std::get<std::tuple<int, int, int>>(padding);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if (auto pv = std::get_if<int>(&dilation); pv) {
 | 
			
		||||
          dilation_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};
 | 
			
		||||
        } else {
 | 
			
		||||
          dilation_tuple = std::get<std::tuple<int, int, int>>(dilation);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        return conv3d(
 | 
			
		||||
            input,
 | 
			
		||||
            weight,
 | 
			
		||||
            stride_tuple,
 | 
			
		||||
            padding_tuple,
 | 
			
		||||
            dilation_tuple,
 | 
			
		||||
            groups,
 | 
			
		||||
            s);
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      "stride"_a = 1,
 | 
			
		||||
      "padding"_a = 0,
 | 
			
		||||
      "dilation"_a = 1,
 | 
			
		||||
      "groups"_a = 1,
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def conv3d(input: array, weight: array, /, stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, dilation: Union[int, Tuple[int, int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        3D convolution over an input with several channels
 | 
			
		||||
 | 
			
		||||
        Note: Only the default ``groups=1`` is currently supported.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            input (array): input array of shape ``(N, D, H, W, C_in)``
 | 
			
		||||
            weight (array): weight array of shape ``(C_out, D, H, W, C_in)``
 | 
			
		||||
            stride (int or tuple(int), optional): :obj:`tuple` of size 3 with
 | 
			
		||||
                kernel strides. All spatial dimensions get the same stride if
 | 
			
		||||
                only one number is specified. Default: ``1``.
 | 
			
		||||
            padding (int or tuple(int), optional): :obj:`tuple` of size 3 with
 | 
			
		||||
                symmetric input padding. All spatial dimensions get the same
 | 
			
		||||
                padding if only one number is specified. Default: ``0``.
 | 
			
		||||
            dilation (int or tuple(int), optional): :obj:`tuple` of size 3 with
 | 
			
		||||
                kernel dilation. All spatial dimensions get the same dilation
 | 
			
		||||
                if only one number is specified. Default: ``1``
 | 
			
		||||
            groups (int, optional): input feature groups. Default: ``1``.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            array: The convolved array.
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "conv_general",
 | 
			
		||||
      [](const array& input,
 | 
			
		||||
         const array& weight,
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user