diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 4aa5e88b7..e174bc9b7 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -1248,6 +1248,10 @@ array pad( {a, astype(pad_value, a.dtype(), s)}); } else if (mode == "edge") { return edge_pad(a, axes, low_pad_size, high_pad_size, out_shape, s); + } else if (mode == "symmetric") { + return symmetric_pad(a, axes, low_pad_size, high_pad_size, out_shape, s); + } else if (mode == "reflect") { + return reflect_pad(a, axes, low_pad_size, high_pad_size, out_shape, s); } else { std::ostringstream msg; msg << "Invalid padding mode (" << mode << ") passed to pad"; @@ -5116,4 +5120,90 @@ array contiguous( {a}); } +array symmetric_pad( + const array& a, + const std::vector& axes, + const Shape& low_pad_size, + const Shape& high_pad_size, + const Shape& out_shape, + StreamOrDevice s /* = {} */) { + array out = zeros(out_shape, a.dtype(), s); + auto stops = a.shape(); + for (int i = 0; i < stops.size(); i++) { + stops[i] += low_pad_size[i]; + } + // Copy over values from the unpadded array + array padded = slice_update(out, a, low_pad_size, stops, s); + + for (int axis = 0; axis < a.ndim(); axis++) { + if (low_pad_size[axis] > 0) { + Shape starts(a.ndim(), 0); + starts[axis] = 1; + auto stops = a.shape(); + stops[axis] = low_pad_size[axis] + 1; + array edge_value = slice(padded, starts, stops, s); + + starts[axis] = 0; + stops[axis] = low_pad_size[axis]; + padded = slice_update(padded, edge_value, starts, stops, s); + } + + if (high_pad_size[axis] > 0) { + Shape starts(a.ndim(), 0); + starts[axis] = -high_pad_size[axis] - 1; + auto stops = out.shape(); + stops[axis] = -high_pad_size[axis]; + array edge_value = slice(padded, starts, stops, s); + + starts[axis] = -high_pad_size[axis]; + stops[axis] = out.shape(axis); + padded = slice_update(padded, edge_value, starts, stops, s); + } + } + return padded; +} + +array reflect_pad( + const array& a, + const std::vector& axes, + const Shape& low_pad_size, + const Shape& high_pad_size, + const Shape& out_shape, + StreamOrDevice s /* = {} */) { + array out = zeros(out_shape, a.dtype(), s); + auto stops = a.shape(); + for (int i = 0; i < stops.size(); i++) { + stops[i] += low_pad_size[i]; + } + // Copy over values from the unpadded array + array padded = slice_update(out, a, low_pad_size, stops, s); + + for (int axis = 0; axis < a.ndim(); axis++) { + if (low_pad_size[axis] > 0) { + Shape starts(a.ndim(), 0); + starts[axis] = 1; + auto stops = a.shape(); + stops[axis] = low_pad_size[axis] + 1; + array edge_value = slice(padded, starts, stops, s); + + starts[axis] = 0; + stops[axis] = low_pad_size[axis]; + padded = slice_update(padded, edge_value, starts, stops, s); + } + + if (high_pad_size[axis] > 0) { + Shape starts(a.ndim(), 0); + starts[axis] = -high_pad_size[axis] - 1; + auto stops = out.shape(); + stops[axis] = -high_pad_size[axis]; + array edge_value = slice(padded, starts, stops, s); + + starts[axis] = -high_pad_size[axis]; + stops[axis] = out.shape(axis); + padded = slice_update(padded, edge_value, starts, stops, s); + } + } + return padded; +} + } // namespace mlx::core diff --git a/python/src/ops.cpp b/python/src/ops.cpp index a1e77d681..9e81adf4e 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3102,27 +3102,29 @@ void init_ops(nb::module_& m) { nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def pad(a: array, pad_width: Union[int, tuple[int], tuple[int, int], list[tuple[int, int]]], mode: Literal['constant', 'edge'] = 'constant', constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array"), + "def pad(a: array, pad_width: Union[int, tuple[int], tuple[int, int], list[tuple[int, int]]], mode: Literal['constant', 'edge', 'symmetric', 'reflect'] = 'constant', constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - Pad an array with a constant value - - Args: - a (array): Input array. - pad_width (int, tuple(int), tuple(int, int) or list(tuple(int, int))): Number of padded - values to add to the edges of each axis:``((before_1, after_1), - (before_2, after_2), ..., (before_N, after_N))``. If a single pair - of integers is passed then ``(before_i, after_i)`` are all the same. - If a single integer or tuple with a single integer is passed then - all axes are extended by the same number on each side. - mode: Padding mode. One of the following strings: - "constant" (default): Pads with a constant value. - "edge": Pads with the edge values of array. - constant_value (array or scalar, optional): Optional constant value - to pad the edges of the array with. - - Returns: - array: The padded array. - )pbdoc"); + Pad an array with a constant value or using other modes. + + Args: + a (array): Input array. + pad_width (int, tuple(int), tuple(int, int) or list(tuple(int, int))): Number of padded + values to add to the edges of each axis:``((before_1, after_1), + (before_2, after_2), ..., (before_N, after_N))``. If a single pair + of integers is passed then ``(before_i, after_i)`` are all the same. + If a single integer or tuple with a single integer is passed then + all axes are extended by the same number on each side. + mode: Padding mode. One of the following strings: + "constant" (default): Pads with a constant value. + "edge": Pads with the edge values of array. + "symmetric": Pads with the reflection of the array, including the edge values. + "reflect": Pads with the reflection of the array, excluding the edge values. + constant_value (array or scalar, optional): Optional constant value + to pad the edges of the array with. + + Returns: + array: The padded array. + )pbdoc"); m.def( "as_strided", [](const mx::array& a, diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 5e2bae5a0..208f11251 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -3981,3 +3981,42 @@ TEST_CASE("test conv_transpose3d with output_padding") { {1, 2, 4, 4, 1}); CHECK(array_equal(out, expected).item()); } + +TEST_CASE("test pad modes") { + array t = array({{1, 2, 3}, {4, 5, 6}}); + Shape low_pad = {1, 2}; + Shape high_pad = {1, 2}; + + auto constant_padded = + pad(t, {0, 1}, low_pad, high_pad, array(0), "constant"); + auto edge_padded = pad(t, {0, 1}, low_pad, high_pad, array(0), "edge"); + auto symmetric_padded = + pad(t, {0, 1}, low_pad, high_pad, array(0), "symmetric"); + auto reflect_padded = pad(t, {0, 1}, low_pad, high_pad, array(0), "reflect"); + + CHECK(constant_padded.shape() == Shape{4, 7}); + CHECK(edge_padded.shape() == Shape{4, 7}); + CHECK(symmetric_padded.shape() == Shape{4, 7}); + CHECK(reflect_padded.shape() == Shape{4, 7}); + + CHECK(array_equal( + constant_padded, array({0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 2, 1, 0, 0, + 4, 5, 6, 5, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0})) + .item()); + + CHECK(array_equal( + edge_padded, array({1, 1, 1, 1, 2, 3, 3, 4, 5, 6, 6, 5, 4, 4, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})) + + .item()); + + CHECK(array_equal( + symmetric_padded, array({1, 1, 2, 3, 3, 2, 1, 1, 2, 3, 3, 2, 1, 1, + 4, 5, 6, 6, 5, 4, 4, 4, 5, 6, 6, 5, 4, 4})) + .item()); + + CHECK(array_equal( + reflect_padded, array({2, 1, 1, 2, 3, 3, 2, 1, 2, 3, 3, 2, 1, 1, + 5, 4, 4, 5, 6, 6, 5, 5, 4, 4, 5, 6, 6, 5})) + .item()); +}