Add inner / outer op (#348)

* inner / outer impl

* python tests

* ops list and ack

* updated descriptions

* use test helper

* removed dtype check and flatten outer to 1-D

* updated docs

* just use the reshape to flatten
This commit is contained in:
Diogo
2024-01-07 12:01:09 -05:00
committed by GitHub
parent 6ea6b4258d
commit 449b43762e
7 changed files with 140 additions and 5 deletions

View File

@@ -2315,3 +2315,47 @@ TEST_CASE("tensordot") {
{3, 6});
CHECK(array_equal(z, expected).item<bool>());
}
TEST_CASE("outer") {
auto x = arange(1.0, 5.0);
auto y = arange(1.0, 4.0);
auto z = outer(x, y);
auto expected = array(
{1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0}, {4, 3});
CHECK(array_equal(z, expected).item<bool>());
x = ones({5});
y = linspace(-2., 2., 5);
z = outer(x, y);
expected = array(
{-2., -1., 0., 1., 2., -2., -1., 0., 1., 2., -2., -1., 0.,
1., 2., -2., -1., 0., 1., 2., -2., -1., 0., 1., 2.},
{5, 5});
CHECK(array_equal(z, expected).item<bool>());
}
TEST_CASE("inner") {
CHECK_THROWS_AS(
inner(reshape(arange(5.), {1, 5}), reshape(arange(6.), {2, 3})),
std::invalid_argument);
auto x = array({1., 2., 3.});
auto y = array({0., 1., 0.});
auto z = inner(x, y);
CHECK_EQ(z.item<float>(), 2.f);
x = reshape(arange(24.), {2, 3, 4});
y = arange(4.);
z = inner(x, y);
auto expected = array({14., 38., 62., 86., 110., 134.}, {2, 3});
CHECK(array_equal(z, expected).item<bool>());
x = reshape(arange(2.), {1, 1, 2});
y = reshape(arange(6.), {3, 2});
z = inner(x, y);
expected = array({1., 3., 5.}, {1, 1, 3});
CHECK(array_equal(z, expected).item<bool>());
z = inner(eye(2), array(7.));
expected = array({7., 0., 0., 7.}, {2, 2});
CHECK(array_equal(z, expected).item<bool>());
}