Add Tensordot op (#344)

This commit is contained in:
Diogo
2024-01-02 20:15:00 -05:00
committed by GitHub
parent af66a09bde
commit 0782a4573a
7 changed files with 198 additions and 1 deletions

View File

@@ -2277,4 +2277,41 @@ TEST_CASE("test repeat") {
// negative repeats
CHECK_THROWS_AS(repeat(data_3, -3, 0), std::invalid_argument);
}
TEST_CASE("tensordot") {
auto x = reshape(arange(60.), {3, 4, 5});
auto y = reshape(arange(24.), {4, 3, 2});
auto z = tensordot(x, y, {{1, 0}, {0, 1}});
auto expected = array(
{4400, 4730, 4532, 4874, 4664, 5018, 4796, 5162, 4928, 5306}, {5, 2});
CHECK(array_equal(z, expected).item<bool>());
x = reshape(arange(360.), {3, 4, 5, 6});
y = reshape(arange(360.), {6, 4, 5, 3});
CHECK_THROWS_AS(
tensordot(x, y, {{2, 1, 3}, {1, 2, 0}}), std::invalid_argument);
x = reshape(arange(60.), {3, 4, 5});
y = reshape(arange(120.), {4, 5, 6});
z = tensordot(x, y, 2);
expected = array(
{14820.,
15010.,
15200.,
15390.,
15580.,
15770.,
37620.,
38210.,
38800.,
39390.,
39980.,
40570.,
60420.,
61410.,
62400.,
63390.,
64380.,
65370.},
{3, 6});
CHECK(array_equal(z, expected).item<bool>());
}