mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	Add "edge" mode to mx.pad (#1309)
* Add edge padding mode * fix pad in pooling * string arg instead of enum
This commit is contained in:
		@@ -2843,30 +2843,32 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
             std::tuple<int>,
 | 
			
		||||
             std::pair<int, int>,
 | 
			
		||||
             std::vector<std::pair<int, int>>>& pad_width,
 | 
			
		||||
         const std::string mode,
 | 
			
		||||
         const ScalarOrArray& constant_value,
 | 
			
		||||
         StreamOrDevice s) {
 | 
			
		||||
        if (auto pv = std::get_if<int>(&pad_width); pv) {
 | 
			
		||||
          return pad(a, *pv, to_array(constant_value), s);
 | 
			
		||||
          return pad(a, *pv, to_array(constant_value), mode, s);
 | 
			
		||||
        } else if (auto pv = std::get_if<std::tuple<int>>(&pad_width); pv) {
 | 
			
		||||
          return pad(a, std::get<0>(*pv), to_array(constant_value), s);
 | 
			
		||||
          return pad(a, std::get<0>(*pv), to_array(constant_value), mode, s);
 | 
			
		||||
        } else if (auto pv = std::get_if<std::pair<int, int>>(&pad_width); pv) {
 | 
			
		||||
          return pad(a, *pv, to_array(constant_value), s);
 | 
			
		||||
          return pad(a, *pv, to_array(constant_value), mode, s);
 | 
			
		||||
        } else {
 | 
			
		||||
          auto v = std::get<std::vector<std::pair<int, int>>>(pad_width);
 | 
			
		||||
          if (v.size() == 1) {
 | 
			
		||||
            return pad(a, v[0], to_array(constant_value), s);
 | 
			
		||||
            return pad(a, v[0], to_array(constant_value), mode, s);
 | 
			
		||||
          } else {
 | 
			
		||||
            return pad(a, v, to_array(constant_value), s);
 | 
			
		||||
            return pad(a, v, to_array(constant_value), mode, s);
 | 
			
		||||
          }
 | 
			
		||||
        }
 | 
			
		||||
      },
 | 
			
		||||
      nb::arg(),
 | 
			
		||||
      "pad_width"_a,
 | 
			
		||||
      "mode"_a = "constant",
 | 
			
		||||
      "constant_values"_a = 0,
 | 
			
		||||
      nb::kw_only(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def pad(a: array, pad_width: Union[int, Tuple[int], Tuple[int, int], List[Tuple[int, int]]], constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
          "def pad(a: array, pad_width: Union[int, Tuple[int], Tuple[int, int], List[Tuple[int, int]]], mode: Literal['constant', 'edge'] = 'constant', constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        Pad an array with a constant value
 | 
			
		||||
 | 
			
		||||
@@ -2878,6 +2880,9 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
              of integers is passed then ``(before_i, after_i)`` are all the same.
 | 
			
		||||
              If a single integer or tuple with a single integer is passed then
 | 
			
		||||
              all axes are extended by the same number on each side.
 | 
			
		||||
            mode: Padding mode. One of the following strings:
 | 
			
		||||
              "constant" (default): Pads with a constant value.
 | 
			
		||||
              "edge": Pads with the edge values of array.
 | 
			
		||||
            constant_value (array or scalar, optional): Optional constant value
 | 
			
		||||
              to pad the edges of the array with.
 | 
			
		||||
 | 
			
		||||
@@ -3155,7 +3160,8 @@ void init_ops(nb::module_& m) {
 | 
			
		||||
          } else { // Even sizes use asymmetric padding
 | 
			
		||||
            int pad_l = wt.size() / 2;
 | 
			
		||||
            int pad_r = std::max(0, pad_l - 1);
 | 
			
		||||
            in = pad(in, {{0, 0}, {pad_l, pad_r}, {0, 0}}, array(0), s);
 | 
			
		||||
            in = pad(
 | 
			
		||||
                in, {{0, 0}, {pad_l, pad_r}, {0, 0}}, array(0), "constant", s);
 | 
			
		||||
          }
 | 
			
		||||
 | 
			
		||||
        } else {
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user