mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-03 07:41:13 +08:00
Added Symmetric and reflect modes to pad
This commit is contained in:
parent
5a1a5d5ed1
commit
4caeb05c64
90
mlx/ops.cpp
90
mlx/ops.cpp
@ -1248,6 +1248,10 @@ array pad(
|
|||||||
{a, astype(pad_value, a.dtype(), s)});
|
{a, astype(pad_value, a.dtype(), s)});
|
||||||
} else if (mode == "edge") {
|
} else if (mode == "edge") {
|
||||||
return edge_pad(a, axes, low_pad_size, high_pad_size, out_shape, s);
|
return edge_pad(a, axes, low_pad_size, high_pad_size, out_shape, s);
|
||||||
|
} else if (mode == "symmetric") {
|
||||||
|
return symmetric_pad(a, axes, low_pad_size, high_pad_size, out_shape, s);
|
||||||
|
} else if (mode == "reflect") {
|
||||||
|
return reflect_pad(a, axes, low_pad_size, high_pad_size, out_shape, s);
|
||||||
} else {
|
} else {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "Invalid padding mode (" << mode << ") passed to pad";
|
msg << "Invalid padding mode (" << mode << ") passed to pad";
|
||||||
@ -5116,4 +5120,90 @@ array contiguous(
|
|||||||
{a});
|
{a});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array symmetric_pad(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
const Shape& low_pad_size,
|
||||||
|
const Shape& high_pad_size,
|
||||||
|
const Shape& out_shape,
|
||||||
|
StreamOrDevice s /* = {} */) {
|
||||||
|
array out = zeros(out_shape, a.dtype(), s);
|
||||||
|
auto stops = a.shape();
|
||||||
|
for (int i = 0; i < stops.size(); i++) {
|
||||||
|
stops[i] += low_pad_size[i];
|
||||||
|
}
|
||||||
|
// Copy over values from the unpadded array
|
||||||
|
array padded = slice_update(out, a, low_pad_size, stops, s);
|
||||||
|
|
||||||
|
for (int axis = 0; axis < a.ndim(); axis++) {
|
||||||
|
if (low_pad_size[axis] > 0) {
|
||||||
|
Shape starts(a.ndim(), 0);
|
||||||
|
starts[axis] = 1;
|
||||||
|
auto stops = a.shape();
|
||||||
|
stops[axis] = low_pad_size[axis] + 1;
|
||||||
|
array edge_value = slice(padded, starts, stops, s);
|
||||||
|
|
||||||
|
starts[axis] = 0;
|
||||||
|
stops[axis] = low_pad_size[axis];
|
||||||
|
padded = slice_update(padded, edge_value, starts, stops, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (high_pad_size[axis] > 0) {
|
||||||
|
Shape starts(a.ndim(), 0);
|
||||||
|
starts[axis] = -high_pad_size[axis] - 1;
|
||||||
|
auto stops = out.shape();
|
||||||
|
stops[axis] = -high_pad_size[axis];
|
||||||
|
array edge_value = slice(padded, starts, stops, s);
|
||||||
|
|
||||||
|
starts[axis] = -high_pad_size[axis];
|
||||||
|
stops[axis] = out.shape(axis);
|
||||||
|
padded = slice_update(padded, edge_value, starts, stops, s);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return padded;
|
||||||
|
}
|
||||||
|
|
||||||
|
array reflect_pad(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
const Shape& low_pad_size,
|
||||||
|
const Shape& high_pad_size,
|
||||||
|
const Shape& out_shape,
|
||||||
|
StreamOrDevice s /* = {} */) {
|
||||||
|
array out = zeros(out_shape, a.dtype(), s);
|
||||||
|
auto stops = a.shape();
|
||||||
|
for (int i = 0; i < stops.size(); i++) {
|
||||||
|
stops[i] += low_pad_size[i];
|
||||||
|
}
|
||||||
|
// Copy over values from the unpadded array
|
||||||
|
array padded = slice_update(out, a, low_pad_size, stops, s);
|
||||||
|
|
||||||
|
for (int axis = 0; axis < a.ndim(); axis++) {
|
||||||
|
if (low_pad_size[axis] > 0) {
|
||||||
|
Shape starts(a.ndim(), 0);
|
||||||
|
starts[axis] = 1;
|
||||||
|
auto stops = a.shape();
|
||||||
|
stops[axis] = low_pad_size[axis] + 1;
|
||||||
|
array edge_value = slice(padded, starts, stops, s);
|
||||||
|
|
||||||
|
starts[axis] = 0;
|
||||||
|
stops[axis] = low_pad_size[axis];
|
||||||
|
padded = slice_update(padded, edge_value, starts, stops, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (high_pad_size[axis] > 0) {
|
||||||
|
Shape starts(a.ndim(), 0);
|
||||||
|
starts[axis] = -high_pad_size[axis] - 1;
|
||||||
|
auto stops = out.shape();
|
||||||
|
stops[axis] = -high_pad_size[axis];
|
||||||
|
array edge_value = slice(padded, starts, stops, s);
|
||||||
|
|
||||||
|
starts[axis] = -high_pad_size[axis];
|
||||||
|
stops[axis] = out.shape(axis);
|
||||||
|
padded = slice_update(padded, edge_value, starts, stops, s);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return padded;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -3102,27 +3102,29 @@ void init_ops(nb::module_& m) {
|
|||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
nb::sig(
|
nb::sig(
|
||||||
"def pad(a: array, pad_width: Union[int, tuple[int], tuple[int, int], list[tuple[int, int]]], mode: Literal['constant', 'edge'] = 'constant', constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
|
"def pad(a: array, pad_width: Union[int, tuple[int], tuple[int, int], list[tuple[int, int]]], mode: Literal['constant', 'edge', 'symmetric', 'reflect'] = 'constant', constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Pad an array with a constant value
|
Pad an array with a constant value or using other modes.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
a (array): Input array.
|
a (array): Input array.
|
||||||
pad_width (int, tuple(int), tuple(int, int) or list(tuple(int, int))): Number of padded
|
pad_width (int, tuple(int), tuple(int, int) or list(tuple(int, int))): Number of padded
|
||||||
values to add to the edges of each axis:``((before_1, after_1),
|
values to add to the edges of each axis:``((before_1, after_1),
|
||||||
(before_2, after_2), ..., (before_N, after_N))``. If a single pair
|
(before_2, after_2), ..., (before_N, after_N))``. If a single pair
|
||||||
of integers is passed then ``(before_i, after_i)`` are all the same.
|
of integers is passed then ``(before_i, after_i)`` are all the same.
|
||||||
If a single integer or tuple with a single integer is passed then
|
If a single integer or tuple with a single integer is passed then
|
||||||
all axes are extended by the same number on each side.
|
all axes are extended by the same number on each side.
|
||||||
mode: Padding mode. One of the following strings:
|
mode: Padding mode. One of the following strings:
|
||||||
"constant" (default): Pads with a constant value.
|
"constant" (default): Pads with a constant value.
|
||||||
"edge": Pads with the edge values of array.
|
"edge": Pads with the edge values of array.
|
||||||
constant_value (array or scalar, optional): Optional constant value
|
"symmetric": Pads with the reflection of the array, including the edge values.
|
||||||
to pad the edges of the array with.
|
"reflect": Pads with the reflection of the array, excluding the edge values.
|
||||||
|
constant_value (array or scalar, optional): Optional constant value
|
||||||
Returns:
|
to pad the edges of the array with.
|
||||||
array: The padded array.
|
|
||||||
)pbdoc");
|
Returns:
|
||||||
|
array: The padded array.
|
||||||
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"as_strided",
|
"as_strided",
|
||||||
[](const mx::array& a,
|
[](const mx::array& a,
|
||||||
|
@ -3981,3 +3981,42 @@ TEST_CASE("test conv_transpose3d with output_padding") {
|
|||||||
{1, 2, 4, 4, 1});
|
{1, 2, 4, 4, 1});
|
||||||
CHECK(array_equal(out, expected).item<bool>());
|
CHECK(array_equal(out, expected).item<bool>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test pad modes") {
|
||||||
|
array t = array({{1, 2, 3}, {4, 5, 6}});
|
||||||
|
Shape low_pad = {1, 2};
|
||||||
|
Shape high_pad = {1, 2};
|
||||||
|
|
||||||
|
auto constant_padded =
|
||||||
|
pad(t, {0, 1}, low_pad, high_pad, array(0), "constant");
|
||||||
|
auto edge_padded = pad(t, {0, 1}, low_pad, high_pad, array(0), "edge");
|
||||||
|
auto symmetric_padded =
|
||||||
|
pad(t, {0, 1}, low_pad, high_pad, array(0), "symmetric");
|
||||||
|
auto reflect_padded = pad(t, {0, 1}, low_pad, high_pad, array(0), "reflect");
|
||||||
|
|
||||||
|
CHECK(constant_padded.shape() == Shape{4, 7});
|
||||||
|
CHECK(edge_padded.shape() == Shape{4, 7});
|
||||||
|
CHECK(symmetric_padded.shape() == Shape{4, 7});
|
||||||
|
CHECK(reflect_padded.shape() == Shape{4, 7});
|
||||||
|
|
||||||
|
CHECK(array_equal(
|
||||||
|
constant_padded, array({0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 2, 1, 0, 0,
|
||||||
|
4, 5, 6, 5, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0}))
|
||||||
|
.item<bool>());
|
||||||
|
|
||||||
|
CHECK(array_equal(
|
||||||
|
edge_padded, array({1, 1, 1, 1, 2, 3, 3, 4, 5, 6, 6, 5, 4, 4,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}))
|
||||||
|
|
||||||
|
.item<bool>());
|
||||||
|
|
||||||
|
CHECK(array_equal(
|
||||||
|
symmetric_padded, array({1, 1, 2, 3, 3, 2, 1, 1, 2, 3, 3, 2, 1, 1,
|
||||||
|
4, 5, 6, 6, 5, 4, 4, 4, 5, 6, 6, 5, 4, 4}))
|
||||||
|
.item<bool>());
|
||||||
|
|
||||||
|
CHECK(array_equal(
|
||||||
|
reflect_padded, array({2, 1, 1, 2, 3, 3, 2, 1, 2, 3, 3, 2, 1, 1,
|
||||||
|
5, 4, 4, 5, 6, 6, 5, 5, 4, 4, 5, 6, 6, 5}))
|
||||||
|
.item<bool>());
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user