diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 7866db19c..c4f319d46 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -3915,60 +3915,66 @@ TEST_CASE("test bitwise shift operations") { TEST_CASE("test conv_transpose1d with output_padding") { auto in = array({1.0, 2.0, 3.0}, {1, 1, 3}); - auto wt = array({1.0, 1.0}, {1, 1, 2}); + auto wt = array({1.0, 1.0, 1.0}, {1, 1, 3}); int stride = 2; int padding = 0; + int dilation = 1; int output_padding = 1; + int groups = 1; - auto out = conv_transpose1d(in, wt, stride, padding, 1, output_padding, 1); - auto expected = array({1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 0.0}, {1, 1, 7}); + auto out = conv_transpose1d( + in, wt, stride, padding, dilation, output_padding, groups); + auto expected = array({6.0, 0.0}, {1, 2, 1}); CHECK(array_equal(out, expected).item()); } TEST_CASE("test conv_transpose2d with output_padding") { auto in = array({1.0, 2.0, 3.0, 4.0}, {1, 1, 2, 2}); - auto wt = array({1.0, 1.0, 1.0, 1.0}, {1, 1, 2, 2}); + auto wt = array({1.0, 1.0, 1.0, 1.0}, {2, 1, 1, 2}); std::pair stride{2, 2}; std::pair padding{0, 0}; std::pair output_padding{1, 1}; + std::pair dilation{1, 1}; + int groups = 1; - auto out = - conv_transpose2d(in, wt, stride, padding, {1, 1}, output_padding, 1); + auto out = conv_transpose2d( + in, wt, stride, padding, dilation, output_padding, groups); auto expected = array( - {1.0, - 1.0, - 2.0, - 2.0, + {3.0, + 3.0, + 0.0, + 0.0, + 7.0, + 7.0, + 0.0, + 0.0, + 0.0, 0.0, - 3.0, - 3.0, - 4.0, - 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, - {1, 1, 5, 5}); + {1, 2, 4, 2}); CHECK(array_equal(out, expected).item()); } TEST_CASE("test conv_transpose3d with output_padding") { auto in = array({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}, {1, 1, 2, 2, 2}); - auto wt = array({1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}, {1, 1, 2, 2, 2}); + auto wt = array({1.0, 1.0}, {1, 1, 1, 1, 2}); std::tuple stride{2, 2, 2}; std::tuple padding{0, 0, 0}; std::tuple output_padding{1, 1, 1}; + std::tuple dilation{1, 1, 1}; + int groups = 1; - auto out = - conv_transpose3d(in, wt, stride, padding, {1, 1, 1}, output_padding, 1); + auto out = conv_transpose3d( + in, wt, stride, padding, dilation, output_padding, groups); auto expected = array( - {1.0, 1.0, 2.0, 2.0, 0.0, 3.0, 3.0, 4.0, 4.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 5.0, 5.0, 6.0, 6.0, 0.0, 7.0, 7.0, 8.0, 8.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, - {1, 1, 5, 5, 5}); + {3.0, 0.0, 7.0, 0.0, 0.0, 0.0, 0.0, 0.0, 11.0, 0.0, 15.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, + {1, 2, 4, 4, 1}); CHECK(array_equal(out, expected).item()); -} \ No newline at end of file +}