Added conv_transpose fixes

This commit is contained in:
paramthakkar123
2025-04-19 20:59:53 +05:30
parent 5f04c0f818
commit 59b934dcf9
3 changed files with 121 additions and 3 deletions

View File

@@ -3911,4 +3911,64 @@ TEST_CASE("test bitwise shift operations") {
CHECK_EQ(right_shift_bool_result.dtype(), uint8);
CHECK(array_equal(right_shift_bool_result, full({4}, 0, uint8)).item<bool>());
}
TEST_CASE("test conv_transpose1d with output_padding") {
auto in = array({1.0, 2.0, 3.0}, {1, 1, 3});
auto wt = array({1.0, 1.0}, {1, 1, 2});
int stride = 2;
int padding = 0;
int output_padding = 1;
auto out = conv_transpose1d(in, wt, stride, padding, 1, output_padding, 1);
auto expected = array({1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 0.0}, {1, 1, 7});
CHECK(array_equal(out, expected).item<bool>());
}
TEST_CASE("test conv_transpose2d with output_padding") {
auto in = array({1.0, 2.0, 3.0, 4.0}, {1, 1, 2, 2});
auto wt = array({1.0, 1.0, 1.0, 1.0}, {1, 1, 2, 2});
std::pair<int, int> stride{2, 2};
std::pair<int, int> padding{0, 0};
std::pair<int, int> output_padding{1, 1};
auto out =
conv_transpose2d(in, wt, stride, padding, {1, 1}, output_padding, 1);
auto expected = array(
{1.0,
1.0,
2.0,
2.0,
0.0,
3.0,
3.0,
4.0,
4.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0},
{1, 1, 5, 5});
CHECK(array_equal(out, expected).item<bool>());
}
TEST_CASE("test conv_transpose3d with output_padding") {
auto in = array({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}, {1, 1, 2, 2, 2});
auto wt = array({1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}, {1, 1, 2, 2, 2});
std::tuple<int, int, int> stride{2, 2, 2};
std::tuple<int, int, int> padding{0, 0, 0};
std::tuple<int, int, int> output_padding{1, 1, 1};
auto out =
conv_transpose3d(in, wt, stride, padding, {1, 1, 1}, output_padding, 1);
auto expected = array(
{1.0, 1.0, 2.0, 2.0, 0.0, 3.0, 3.0, 4.0, 4.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 5.0, 5.0, 6.0, 6.0, 0.0, 7.0, 7.0, 8.0, 8.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0},
{1, 1, 5, 5, 5});
CHECK(array_equal(out, expected).item<bool>());
}