Added Symmetric and reflect modes to pad

This commit is contained in:
paramthakkar123
2025-05-07 08:15:35 +05:30
parent 5a1a5d5ed1
commit 4caeb05c64
3 changed files with 151 additions and 20 deletions

View File

@@ -3981,3 +3981,42 @@ TEST_CASE("test conv_transpose3d with output_padding") {
{1, 2, 4, 4, 1});
CHECK(array_equal(out, expected).item<bool>());
}
TEST_CASE("test pad modes") {
array t = array({{1, 2, 3}, {4, 5, 6}});
Shape low_pad = {1, 2};
Shape high_pad = {1, 2};
auto constant_padded =
pad(t, {0, 1}, low_pad, high_pad, array(0), "constant");
auto edge_padded = pad(t, {0, 1}, low_pad, high_pad, array(0), "edge");
auto symmetric_padded =
pad(t, {0, 1}, low_pad, high_pad, array(0), "symmetric");
auto reflect_padded = pad(t, {0, 1}, low_pad, high_pad, array(0), "reflect");
CHECK(constant_padded.shape() == Shape{4, 7});
CHECK(edge_padded.shape() == Shape{4, 7});
CHECK(symmetric_padded.shape() == Shape{4, 7});
CHECK(reflect_padded.shape() == Shape{4, 7});
CHECK(array_equal(
constant_padded, array({0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 2, 1, 0, 0,
4, 5, 6, 5, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0}))
.item<bool>());
CHECK(array_equal(
edge_padded, array({1, 1, 1, 1, 2, 3, 3, 4, 5, 6, 6, 5, 4, 4,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}))
.item<bool>());
CHECK(array_equal(
symmetric_padded, array({1, 1, 2, 3, 3, 2, 1, 1, 2, 3, 3, 2, 1, 1,
4, 5, 6, 6, 5, 4, 4, 4, 5, 6, 6, 5, 4, 4}))
.item<bool>());
CHECK(array_equal(
reflect_padded, array({2, 1, 1, 2, 3, 3, 2, 1, 2, 3, 3, 2, 1, 1,
5, 4, 4, 5, 6, 6, 5, 5, 4, 4, 5, 6, 6, 5}))
.item<bool>());
}