mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-03 07:41:13 +08:00
Added Symmetric and reflect modes to pad
This commit is contained in:
parent
5a1a5d5ed1
commit
4caeb05c64
90
mlx/ops.cpp
90
mlx/ops.cpp
@ -1248,6 +1248,10 @@ array pad(
|
||||
{a, astype(pad_value, a.dtype(), s)});
|
||||
} else if (mode == "edge") {
|
||||
return edge_pad(a, axes, low_pad_size, high_pad_size, out_shape, s);
|
||||
} else if (mode == "symmetric") {
|
||||
return symmetric_pad(a, axes, low_pad_size, high_pad_size, out_shape, s);
|
||||
} else if (mode == "reflect") {
|
||||
return reflect_pad(a, axes, low_pad_size, high_pad_size, out_shape, s);
|
||||
} else {
|
||||
std::ostringstream msg;
|
||||
msg << "Invalid padding mode (" << mode << ") passed to pad";
|
||||
@ -5116,4 +5120,90 @@ array contiguous(
|
||||
{a});
|
||||
}
|
||||
|
||||
array symmetric_pad(
|
||||
const array& a,
|
||||
const std::vector<int>& axes,
|
||||
const Shape& low_pad_size,
|
||||
const Shape& high_pad_size,
|
||||
const Shape& out_shape,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
array out = zeros(out_shape, a.dtype(), s);
|
||||
auto stops = a.shape();
|
||||
for (int i = 0; i < stops.size(); i++) {
|
||||
stops[i] += low_pad_size[i];
|
||||
}
|
||||
// Copy over values from the unpadded array
|
||||
array padded = slice_update(out, a, low_pad_size, stops, s);
|
||||
|
||||
for (int axis = 0; axis < a.ndim(); axis++) {
|
||||
if (low_pad_size[axis] > 0) {
|
||||
Shape starts(a.ndim(), 0);
|
||||
starts[axis] = 1;
|
||||
auto stops = a.shape();
|
||||
stops[axis] = low_pad_size[axis] + 1;
|
||||
array edge_value = slice(padded, starts, stops, s);
|
||||
|
||||
starts[axis] = 0;
|
||||
stops[axis] = low_pad_size[axis];
|
||||
padded = slice_update(padded, edge_value, starts, stops, s);
|
||||
}
|
||||
|
||||
if (high_pad_size[axis] > 0) {
|
||||
Shape starts(a.ndim(), 0);
|
||||
starts[axis] = -high_pad_size[axis] - 1;
|
||||
auto stops = out.shape();
|
||||
stops[axis] = -high_pad_size[axis];
|
||||
array edge_value = slice(padded, starts, stops, s);
|
||||
|
||||
starts[axis] = -high_pad_size[axis];
|
||||
stops[axis] = out.shape(axis);
|
||||
padded = slice_update(padded, edge_value, starts, stops, s);
|
||||
}
|
||||
}
|
||||
return padded;
|
||||
}
|
||||
|
||||
array reflect_pad(
|
||||
const array& a,
|
||||
const std::vector<int>& axes,
|
||||
const Shape& low_pad_size,
|
||||
const Shape& high_pad_size,
|
||||
const Shape& out_shape,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
array out = zeros(out_shape, a.dtype(), s);
|
||||
auto stops = a.shape();
|
||||
for (int i = 0; i < stops.size(); i++) {
|
||||
stops[i] += low_pad_size[i];
|
||||
}
|
||||
// Copy over values from the unpadded array
|
||||
array padded = slice_update(out, a, low_pad_size, stops, s);
|
||||
|
||||
for (int axis = 0; axis < a.ndim(); axis++) {
|
||||
if (low_pad_size[axis] > 0) {
|
||||
Shape starts(a.ndim(), 0);
|
||||
starts[axis] = 1;
|
||||
auto stops = a.shape();
|
||||
stops[axis] = low_pad_size[axis] + 1;
|
||||
array edge_value = slice(padded, starts, stops, s);
|
||||
|
||||
starts[axis] = 0;
|
||||
stops[axis] = low_pad_size[axis];
|
||||
padded = slice_update(padded, edge_value, starts, stops, s);
|
||||
}
|
||||
|
||||
if (high_pad_size[axis] > 0) {
|
||||
Shape starts(a.ndim(), 0);
|
||||
starts[axis] = -high_pad_size[axis] - 1;
|
||||
auto stops = out.shape();
|
||||
stops[axis] = -high_pad_size[axis];
|
||||
array edge_value = slice(padded, starts, stops, s);
|
||||
|
||||
starts[axis] = -high_pad_size[axis];
|
||||
stops[axis] = out.shape(axis);
|
||||
padded = slice_update(padded, edge_value, starts, stops, s);
|
||||
}
|
||||
}
|
||||
return padded;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -3102,9 +3102,9 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def pad(a: array, pad_width: Union[int, tuple[int], tuple[int, int], list[tuple[int, int]]], mode: Literal['constant', 'edge'] = 'constant', constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def pad(a: array, pad_width: Union[int, tuple[int], tuple[int, int], list[tuple[int, int]]], mode: Literal['constant', 'edge', 'symmetric', 'reflect'] = 'constant', constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Pad an array with a constant value
|
||||
Pad an array with a constant value or using other modes.
|
||||
|
||||
Args:
|
||||
a (array): Input array.
|
||||
@ -3117,6 +3117,8 @@ void init_ops(nb::module_& m) {
|
||||
mode: Padding mode. One of the following strings:
|
||||
"constant" (default): Pads with a constant value.
|
||||
"edge": Pads with the edge values of array.
|
||||
"symmetric": Pads with the reflection of the array, including the edge values.
|
||||
"reflect": Pads with the reflection of the array, excluding the edge values.
|
||||
constant_value (array or scalar, optional): Optional constant value
|
||||
to pad the edges of the array with.
|
||||
|
||||
|
@ -3981,3 +3981,42 @@ TEST_CASE("test conv_transpose3d with output_padding") {
|
||||
{1, 2, 4, 4, 1});
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test pad modes") {
|
||||
array t = array({{1, 2, 3}, {4, 5, 6}});
|
||||
Shape low_pad = {1, 2};
|
||||
Shape high_pad = {1, 2};
|
||||
|
||||
auto constant_padded =
|
||||
pad(t, {0, 1}, low_pad, high_pad, array(0), "constant");
|
||||
auto edge_padded = pad(t, {0, 1}, low_pad, high_pad, array(0), "edge");
|
||||
auto symmetric_padded =
|
||||
pad(t, {0, 1}, low_pad, high_pad, array(0), "symmetric");
|
||||
auto reflect_padded = pad(t, {0, 1}, low_pad, high_pad, array(0), "reflect");
|
||||
|
||||
CHECK(constant_padded.shape() == Shape{4, 7});
|
||||
CHECK(edge_padded.shape() == Shape{4, 7});
|
||||
CHECK(symmetric_padded.shape() == Shape{4, 7});
|
||||
CHECK(reflect_padded.shape() == Shape{4, 7});
|
||||
|
||||
CHECK(array_equal(
|
||||
constant_padded, array({0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 2, 1, 0, 0,
|
||||
4, 5, 6, 5, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0}))
|
||||
.item<bool>());
|
||||
|
||||
CHECK(array_equal(
|
||||
edge_padded, array({1, 1, 1, 1, 2, 3, 3, 4, 5, 6, 6, 5, 4, 4,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}))
|
||||
|
||||
.item<bool>());
|
||||
|
||||
CHECK(array_equal(
|
||||
symmetric_padded, array({1, 1, 2, 3, 3, 2, 1, 1, 2, 3, 3, 2, 1, 1,
|
||||
4, 5, 6, 6, 5, 4, 4, 4, 5, 6, 6, 5, 4, 4}))
|
||||
.item<bool>());
|
||||
|
||||
CHECK(array_equal(
|
||||
reflect_padded, array({2, 1, 1, 2, 3, 3, 2, 1, 2, 3, 3, 2, 1, 1,
|
||||
5, 4, 4, 5, 6, 6, 5, 5, 4, 4, 5, 6, 6, 5}))
|
||||
.item<bool>());
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user