Added output_padding parameters in conv_transpose (#2092)

This commit is contained in:
Param Thakkar
2025-04-23 21:56:33 +05:30
committed by GitHub
parent 3836445241
commit 600e87e03c
6 changed files with 366 additions and 14 deletions

View File

@@ -3911,4 +3911,70 @@ TEST_CASE("test bitwise shift operations") {
CHECK_EQ(right_shift_bool_result.dtype(), uint8);
CHECK(array_equal(right_shift_bool_result, full({4}, 0, uint8)).item<bool>());
}
}
TEST_CASE("test conv_transpose1d with output_padding") {
auto in = array({1.0, 2.0, 3.0}, {1, 1, 3});
auto wt = array({1.0, 1.0, 1.0}, {1, 1, 3});
int stride = 2;
int padding = 0;
int dilation = 1;
int output_padding = 1;
int groups = 1;
auto out = conv_transpose1d(
in, wt, stride, padding, dilation, output_padding, groups);
auto expected = array({6.0, 0.0}, {1, 2, 1});
CHECK(array_equal(out, expected).item<bool>());
}
TEST_CASE("test conv_transpose2d with output_padding") {
auto in = array({1.0, 2.0, 3.0, 4.0}, {1, 1, 2, 2});
auto wt = array({1.0, 1.0, 1.0, 1.0}, {2, 1, 1, 2});
std::pair<int, int> stride{2, 2};
std::pair<int, int> padding{0, 0};
std::pair<int, int> output_padding{1, 1};
std::pair<int, int> dilation{1, 1};
int groups = 1;
auto out = conv_transpose2d(
in, wt, stride, padding, dilation, output_padding, groups);
auto expected = array(
{3.0,
3.0,
0.0,
0.0,
7.0,
7.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0},
{1, 2, 4, 2});
CHECK(array_equal(out, expected).item<bool>());
}
TEST_CASE("test conv_transpose3d with output_padding") {
auto in = array({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}, {1, 1, 2, 2, 2});
auto wt = array({1.0, 1.0}, {1, 1, 1, 1, 2});
std::tuple<int, int, int> stride{2, 2, 2};
std::tuple<int, int, int> padding{0, 0, 0};
std::tuple<int, int, int> output_padding{1, 1, 1};
std::tuple<int, int, int> dilation{1, 1, 1};
int groups = 1;
auto out = conv_transpose3d(
in, wt, stride, padding, dilation, output_padding, groups);
auto expected = array(
{3.0, 0.0, 7.0, 0.0, 0.0, 0.0, 0.0, 0.0, 11.0, 0.0, 15.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0},
{1, 2, 4, 4, 1});
CHECK(array_equal(out, expected).item<bool>());
}