Added Symmetric and reflect modes to pad

This commit is contained in:
paramthakkar123 2025-05-07 08:15:35 +05:30
parent 5a1a5d5ed1
commit 4caeb05c64
3 changed files with 151 additions and 20 deletions

View File

@ -1248,6 +1248,10 @@ array pad(
{a, astype(pad_value, a.dtype(), s)});
} else if (mode == "edge") {
return edge_pad(a, axes, low_pad_size, high_pad_size, out_shape, s);
} else if (mode == "symmetric") {
return symmetric_pad(a, axes, low_pad_size, high_pad_size, out_shape, s);
} else if (mode == "reflect") {
return reflect_pad(a, axes, low_pad_size, high_pad_size, out_shape, s);
} else {
std::ostringstream msg;
msg << "Invalid padding mode (" << mode << ") passed to pad";
@ -5116,4 +5120,90 @@ array contiguous(
{a});
}
array symmetric_pad(
const array& a,
const std::vector<int>& axes,
const Shape& low_pad_size,
const Shape& high_pad_size,
const Shape& out_shape,
StreamOrDevice s /* = {} */) {
array out = zeros(out_shape, a.dtype(), s);
auto stops = a.shape();
for (int i = 0; i < stops.size(); i++) {
stops[i] += low_pad_size[i];
}
// Copy over values from the unpadded array
array padded = slice_update(out, a, low_pad_size, stops, s);
for (int axis = 0; axis < a.ndim(); axis++) {
if (low_pad_size[axis] > 0) {
Shape starts(a.ndim(), 0);
starts[axis] = 1;
auto stops = a.shape();
stops[axis] = low_pad_size[axis] + 1;
array edge_value = slice(padded, starts, stops, s);
starts[axis] = 0;
stops[axis] = low_pad_size[axis];
padded = slice_update(padded, edge_value, starts, stops, s);
}
if (high_pad_size[axis] > 0) {
Shape starts(a.ndim(), 0);
starts[axis] = -high_pad_size[axis] - 1;
auto stops = out.shape();
stops[axis] = -high_pad_size[axis];
array edge_value = slice(padded, starts, stops, s);
starts[axis] = -high_pad_size[axis];
stops[axis] = out.shape(axis);
padded = slice_update(padded, edge_value, starts, stops, s);
}
}
return padded;
}
array reflect_pad(
const array& a,
const std::vector<int>& axes,
const Shape& low_pad_size,
const Shape& high_pad_size,
const Shape& out_shape,
StreamOrDevice s /* = {} */) {
array out = zeros(out_shape, a.dtype(), s);
auto stops = a.shape();
for (int i = 0; i < stops.size(); i++) {
stops[i] += low_pad_size[i];
}
// Copy over values from the unpadded array
array padded = slice_update(out, a, low_pad_size, stops, s);
for (int axis = 0; axis < a.ndim(); axis++) {
if (low_pad_size[axis] > 0) {
Shape starts(a.ndim(), 0);
starts[axis] = 1;
auto stops = a.shape();
stops[axis] = low_pad_size[axis] + 1;
array edge_value = slice(padded, starts, stops, s);
starts[axis] = 0;
stops[axis] = low_pad_size[axis];
padded = slice_update(padded, edge_value, starts, stops, s);
}
if (high_pad_size[axis] > 0) {
Shape starts(a.ndim(), 0);
starts[axis] = -high_pad_size[axis] - 1;
auto stops = out.shape();
stops[axis] = -high_pad_size[axis];
array edge_value = slice(padded, starts, stops, s);
starts[axis] = -high_pad_size[axis];
stops[axis] = out.shape(axis);
padded = slice_update(padded, edge_value, starts, stops, s);
}
}
return padded;
}
} // namespace mlx::core

View File

@ -3102,27 +3102,29 @@ void init_ops(nb::module_& m) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def pad(a: array, pad_width: Union[int, tuple[int], tuple[int, int], list[tuple[int, int]]], mode: Literal['constant', 'edge'] = 'constant', constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
"def pad(a: array, pad_width: Union[int, tuple[int], tuple[int, int], list[tuple[int, int]]], mode: Literal['constant', 'edge', 'symmetric', 'reflect'] = 'constant', constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Pad an array with a constant value
Args:
a (array): Input array.
pad_width (int, tuple(int), tuple(int, int) or list(tuple(int, int))): Number of padded
values to add to the edges of each axis:``((before_1, after_1),
(before_2, after_2), ..., (before_N, after_N))``. If a single pair
of integers is passed then ``(before_i, after_i)`` are all the same.
If a single integer or tuple with a single integer is passed then
all axes are extended by the same number on each side.
mode: Padding mode. One of the following strings:
"constant" (default): Pads with a constant value.
"edge": Pads with the edge values of array.
constant_value (array or scalar, optional): Optional constant value
to pad the edges of the array with.
Returns:
array: The padded array.
)pbdoc");
Pad an array with a constant value or using other modes.
Args:
a (array): Input array.
pad_width (int, tuple(int), tuple(int, int) or list(tuple(int, int))): Number of padded
values to add to the edges of each axis:``((before_1, after_1),
(before_2, after_2), ..., (before_N, after_N))``. If a single pair
of integers is passed then ``(before_i, after_i)`` are all the same.
If a single integer or tuple with a single integer is passed then
all axes are extended by the same number on each side.
mode: Padding mode. One of the following strings:
"constant" (default): Pads with a constant value.
"edge": Pads with the edge values of array.
"symmetric": Pads with the reflection of the array, including the edge values.
"reflect": Pads with the reflection of the array, excluding the edge values.
constant_value (array or scalar, optional): Optional constant value
to pad the edges of the array with.
Returns:
array: The padded array.
)pbdoc");
m.def(
"as_strided",
[](const mx::array& a,

View File

@ -3981,3 +3981,42 @@ TEST_CASE("test conv_transpose3d with output_padding") {
{1, 2, 4, 4, 1});
CHECK(array_equal(out, expected).item<bool>());
}
TEST_CASE("test pad modes") {
array t = array({{1, 2, 3}, {4, 5, 6}});
Shape low_pad = {1, 2};
Shape high_pad = {1, 2};
auto constant_padded =
pad(t, {0, 1}, low_pad, high_pad, array(0), "constant");
auto edge_padded = pad(t, {0, 1}, low_pad, high_pad, array(0), "edge");
auto symmetric_padded =
pad(t, {0, 1}, low_pad, high_pad, array(0), "symmetric");
auto reflect_padded = pad(t, {0, 1}, low_pad, high_pad, array(0), "reflect");
CHECK(constant_padded.shape() == Shape{4, 7});
CHECK(edge_padded.shape() == Shape{4, 7});
CHECK(symmetric_padded.shape() == Shape{4, 7});
CHECK(reflect_padded.shape() == Shape{4, 7});
CHECK(array_equal(
constant_padded, array({0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 2, 1, 0, 0,
4, 5, 6, 5, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0}))
.item<bool>());
CHECK(array_equal(
edge_padded, array({1, 1, 1, 1, 2, 3, 3, 4, 5, 6, 6, 5, 4, 4,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}))
.item<bool>());
CHECK(array_equal(
symmetric_padded, array({1, 1, 2, 3, 3, 2, 1, 1, 2, 3, 3, 2, 1, 1,
4, 5, 6, 6, 5, 4, 4, 4, 5, 6, 6, 5, 4, 4}))
.item<bool>());
CHECK(array_equal(
reflect_padded, array({2, 1, 1, 2, 3, 3, 2, 1, 2, 3, 3, 2, 1, 1,
5, 4, 4, 5, 6, 6, 5, 5, 4, 4, 5, 6, 6, 5}))
.item<bool>());
}