From 635ccd9e25e39699f6f555b478dc7a29c98c8bc3 Mon Sep 17 00:00:00 2001 From: Alex Barron Date: Tue, 6 Aug 2024 11:23:10 -0700 Subject: [PATCH] Add "edge" mode to mx.pad (#1309) * Add edge padding mode * fix pad in pooling * string arg instead of enum --- mlx/ops.cpp | 76 ++++++++++++++++++++++++++++++--- mlx/ops.h | 4 ++ mlx/primitives.cpp | 8 ++-- python/mlx/nn/layers/pooling.py | 6 ++- python/src/ops.cpp | 20 ++++++--- python/tests/test_ops.py | 6 +++ 6 files changed, 102 insertions(+), 18 deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 26ac3e98f..53f612642 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -1000,6 +1000,51 @@ array tile( return reshape(x, final_shape, s); } +array edge_pad( + const array& a, + const std::vector<int>& axes, + const std::vector<int>& low_pad_size, + const std::vector<int>& high_pad_size, + const std::vector<int>& out_shape, + StreamOrDevice s /* = {}*/) { + array out = zeros(out_shape, a.dtype(), s); + auto stops = a.shape(); + for (int i = 0; i < stops.size(); i++) { + stops[i] += low_pad_size[i]; + } + // Copy over values from the unpadded array + array padded = slice_update(out, a, low_pad_size, stops, s); + + for (int axis = 0; axis < a.ndim(); axis++) { + if (low_pad_size[axis] > 0) { + std::vector<int> starts(a.ndim(), 0); + starts[axis] = low_pad_size[axis]; + auto stops = out.shape(); + stops[axis] = low_pad_size[axis] + 1; + // Fetch edge values + array edge_value = slice(padded, starts, stops, s); + + starts[axis] = 0; + stops[axis] = low_pad_size[axis]; + // Update edge values in the padded array + padded = slice_update(padded, edge_value, starts, stops, s); + } + + if (high_pad_size[axis] > 0) { + std::vector<int> starts(a.ndim(), 0); + starts[axis] = -high_pad_size[axis] - 1; + auto stops = out.shape(); + stops[axis] = -high_pad_size[axis]; + array edge_value = slice(padded, starts, stops, s); + + starts[axis] = -high_pad_size[axis]; + stops[axis] = out.shape(axis); + padded = slice_update(padded, edge_value, starts, stops, s); + } + } + 
return padded; +} + /** Pad an array with a constant value */ array pad( const array& a, @@ -1007,6 +1052,7 @@ const std::vector<int>& low_pad_size, const std::vector<int>& high_pad_size, const array& pad_value /*= array(0)*/, + const std::string mode /*= "constant"*/, StreamOrDevice s /* = {}*/) { if (axes.size() != low_pad_size.size() || axes.size() != high_pad_size.size()) { @@ -1038,11 +1084,19 @@ array pad( out_shape[ax] += low_pad_size[i] + high_pad_size[i]; } - return array( - out_shape, - a.dtype(), - std::make_shared<Pad>(to_stream(s), axes, low_pad_size, high_pad_size), - {a, astype(pad_value, a.dtype(), s)}); + if (mode == "constant") { + return array( + out_shape, + a.dtype(), + std::make_shared<Pad>(to_stream(s), axes, low_pad_size, high_pad_size), + {a, astype(pad_value, a.dtype(), s)}); + } else if (mode == "edge") { + return edge_pad(a, axes, low_pad_size, high_pad_size, out_shape, s); + } else { + std::ostringstream msg; + msg << "Invalid padding mode (" << mode << ") passed to pad"; + throw std::invalid_argument(msg.str()); + } } /** Pad an array with a constant value along all axes */ @@ -1050,6 +1104,7 @@ array pad( const array& a, const std::vector<std::pair<int, int>>& pad_width, const array& pad_value /*= array(0)*/, + const std::string mode /*= "constant"*/, StreamOrDevice s /*= {}*/) { std::vector<int> axes(a.ndim(), 0); std::iota(axes.begin(), axes.end(), 0); @@ -1062,27 +1117,34 @@ array pad( highs.push_back(pads.second); } - return pad(a, axes, lows, highs, pad_value, s); + return pad(a, axes, lows, highs, pad_value, mode, s); } array pad( const array& a, const std::pair<int, int>& pad_width, const array& pad_value /*= array(0)*/, + const std::string mode /*= "constant"*/, StreamOrDevice s /*= {}*/) { return pad( - a, std::vector<std::pair<int, int>>(a.ndim(), pad_width), pad_value, s); + a, + std::vector<std::pair<int, int>>(a.ndim(), pad_width), + pad_value, + mode, + s); } array pad( const array& a, int pad_width, const array& pad_value /*= array(0)*/, + const std::string mode /*= "constant"*/, StreamOrDevice s /*= 
{}*/) { return pad( a, std::vector<std::pair<int, int>>(a.ndim(), {pad_width, pad_width}), pad_value, + mode, s); } diff --git a/mlx/ops.h b/mlx/ops.h index 448c62a89..2a9c2f961 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -263,6 +263,7 @@ array pad( const std::vector<int>& low_pad_size, const std::vector<int>& high_pad_size, const array& pad_value = array(0), + const std::string mode = "constant", StreamOrDevice s = {}); /** Pad an array with a constant value along all axes */ @@ -270,16 +271,19 @@ array pad( const array& a, const std::vector<std::pair<int, int>>& pad_width, const array& pad_value = array(0), + const std::string mode = "constant", StreamOrDevice s = {}); array pad( const array& a, const std::pair<int, int>& pad_width, const array& pad_value = array(0), + const std::string mode = "constant", StreamOrDevice s = {}); array pad( const array& a, int pad_width, const array& pad_value = array(0), + const std::string mode = "constant", StreamOrDevice s = {}); /** Permutes the dimensions in reverse order. */ diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index ba818c99d..fb4686e84 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -855,8 +855,8 @@ array conv_weight_backward_patches( // Pad input std::vector<int> padded_axes(in.ndim() - 2, 0); std::iota(padded_axes.begin(), padded_axes.end(), 1); - auto in_padded = - pad(in, padded_axes, padding, padding, array(0, in.dtype()), s); + auto in_padded = pad( + in, padded_axes, padding, padding, array(0, in.dtype()), "constant", s); // Resolve strided patches @@ -2289,6 +2289,7 @@ std::vector<array> Pad::jvp( low_pad_size_, high_pad_size_, array(0, tangents[0].dtype()), + "constant", stream())}; } @@ -3740,7 +3741,8 @@ std::vector<array> BlockMaskedMM::vjp( // Pad if needed if ((align_Y != 0) || (align_X != 0)) { - r = pad(r, {-2, -1}, {0, 0}, {align_Y, align_X}, zero, stream()); + r = pad( + r, {-2, -1}, {0, 0}, {align_Y, align_X}, zero, "constant", stream()); } // Reshape diff --git a/python/mlx/nn/layers/pooling.py b/python/mlx/nn/layers/pooling.py index 3733fd777..93ae4d8c2 
100644 --- a/python/mlx/nn/layers/pooling.py +++ b/python/mlx/nn/layers/pooling.py @@ -101,7 +101,11 @@ class _Pool(Module): def __call__(self, x): if any(p[0] > 0 for p in self._padding): - x = mx.pad(x, [(0, 0)] + self._padding + [(0, 0)], self._padding_value) + x = mx.pad( + x, + [(0, 0)] + self._padding + [(0, 0)], + constant_values=self._padding_value, + ) x = _sliding_windows(x, self._kernel_size, self._stride) return self._pooling_function(x, self._axes) diff --git a/python/src/ops.cpp b/python/src/ops.cpp index eba8043f1..e21bf525c 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -2843,30 +2843,32 @@ void init_ops(nb::module_& m) { std::tuple<int>, std::pair<int, int>, std::vector<std::pair<int, int>>>& pad_width, + const std::string mode, const ScalarOrArray& constant_value, StreamOrDevice s) { if (auto pv = std::get_if<int>(&pad_width); pv) { - return pad(a, *pv, to_array(constant_value), s); + return pad(a, *pv, to_array(constant_value), mode, s); } else if (auto pv = std::get_if<std::tuple<int>>(&pad_width); pv) { - return pad(a, std::get<0>(*pv), to_array(constant_value), s); + return pad(a, std::get<0>(*pv), to_array(constant_value), mode, s); } else if (auto pv = std::get_if<std::pair<int, int>>(&pad_width); pv) { - return pad(a, *pv, to_array(constant_value), s); + return pad(a, *pv, to_array(constant_value), mode, s); } else { auto v = std::get<std::vector<std::pair<int, int>>>(pad_width); if (v.size() == 1) { - return pad(a, v[0], to_array(constant_value), s); + return pad(a, v[0], to_array(constant_value), mode, s); } else { - return pad(a, v, to_array(constant_value), s); + return pad(a, v, to_array(constant_value), mode, s); } } }, nb::arg(), "pad_width"_a, + "mode"_a = "constant", "constant_values"_a = 0, nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def pad(a: array, pad_width: Union[int, Tuple[int], Tuple[int, int], List[Tuple[int, int]]], constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array"), + "def pad(a: array, pad_width: Union[int, Tuple[int], Tuple[int, int], List[Tuple[int, 
int]]], mode: Literal['constant', 'edge'] = 'constant', constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Pad an array with a constant value @@ -2878,6 +2880,9 @@ void init_ops(nb::module_& m) { of integers is passed then ``(before_i, after_i)`` are all the same. If a single integer or tuple with a single integer is passed then all axes are extended by the same number on each side. + mode: Padding mode. One of the following strings: + "constant" (default): Pads with a constant value. + "edge": Pads with the edge values of array. constant_value (array or scalar, optional): Optional constant value to pad the edges of the array with. @@ -3155,7 +3160,8 @@ void init_ops(nb::module_& m) { } else { // Even sizes use asymmetric padding int pad_l = wt.size() / 2; int pad_r = std::max(0, pad_l - 1); - in = pad(in, {{0, 0}, {pad_l, pad_r}, {0, 0}}, array(0), s); + in = pad( + in, {{0, 0}, {pad_l, pad_r}, {0, 0}}, array(0), "constant", s); } } else { diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index c3fe0866b..2b8cdfbb8 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1623,6 +1623,12 @@ class TestOps(mlx_tests.MLXTestCase): self.assertEqual(list(b_npy.shape), list(b_mlx.shape)) self.assertTrue(np.allclose(b_npy, b_mlx, atol=1e-6)) + b_npy = np.pad(a_npy, pw, mode="edge") + b_mlx = mx.pad(a_mlx, pw, mode="edge") + + self.assertEqual(list(b_npy.shape), list(b_mlx.shape)) + self.assertTrue(np.allclose(b_npy, b_mlx, atol=1e-6)) + a = mx.zeros((1, 1, 1)) self.assertEqual(mx.pad(a, 1).shape, (3, 3, 3)) self.assertEqual(mx.pad(a, (1,)).shape, (3, 3, 3))