mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Added Symmetric and reflect modes to pad
This commit is contained in:
@@ -3981,3 +3981,42 @@ TEST_CASE("test conv_transpose3d with output_padding") {
|
||||
{1, 2, 4, 4, 1});
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test pad modes") {
|
||||
array t = array({{1, 2, 3}, {4, 5, 6}});
|
||||
Shape low_pad = {1, 2};
|
||||
Shape high_pad = {1, 2};
|
||||
|
||||
auto constant_padded =
|
||||
pad(t, {0, 1}, low_pad, high_pad, array(0), "constant");
|
||||
auto edge_padded = pad(t, {0, 1}, low_pad, high_pad, array(0), "edge");
|
||||
auto symmetric_padded =
|
||||
pad(t, {0, 1}, low_pad, high_pad, array(0), "symmetric");
|
||||
auto reflect_padded = pad(t, {0, 1}, low_pad, high_pad, array(0), "reflect");
|
||||
|
||||
CHECK(constant_padded.shape() == Shape{4, 7});
|
||||
CHECK(edge_padded.shape() == Shape{4, 7});
|
||||
CHECK(symmetric_padded.shape() == Shape{4, 7});
|
||||
CHECK(reflect_padded.shape() == Shape{4, 7});
|
||||
|
||||
CHECK(array_equal(
|
||||
constant_padded, array({0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 2, 1, 0, 0,
|
||||
4, 5, 6, 5, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0}))
|
||||
.item<bool>());
|
||||
|
||||
CHECK(array_equal(
|
||||
edge_padded, array({1, 1, 1, 1, 2, 3, 3, 4, 5, 6, 6, 5, 4, 4,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}))
|
||||
|
||||
.item<bool>());
|
||||
|
||||
CHECK(array_equal(
|
||||
symmetric_padded, array({1, 1, 2, 3, 3, 2, 1, 1, 2, 3, 3, 2, 1, 1,
|
||||
4, 5, 6, 6, 5, 4, 4, 4, 5, 6, 6, 5, 4, 4}))
|
||||
.item<bool>());
|
||||
|
||||
CHECK(array_equal(
|
||||
reflect_padded, array({2, 1, 1, 2, 3, 3, 2, 1, 2, 3, 3, 2, 1, 1,
|
||||
5, 4, 4, 5, 6, 6, 5, 5, 4, 4, 5, 6, 6, 5}))
|
||||
.item<bool>());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user