mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
Add inner / outer op (#348)
* inner / outer impl * python tests * ops list and ack * updated descriptions * use test helper * removed dtype check and flatten outer to 1-D * updated docs * just use the reshape to flatten
This commit is contained in:
@@ -2315,3 +2315,47 @@ TEST_CASE("tensordot") {
|
||||
{3, 6});
|
||||
CHECK(array_equal(z, expected).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("outer") {
|
||||
auto x = arange(1.0, 5.0);
|
||||
auto y = arange(1.0, 4.0);
|
||||
auto z = outer(x, y);
|
||||
auto expected = array(
|
||||
{1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0}, {4, 3});
|
||||
CHECK(array_equal(z, expected).item<bool>());
|
||||
|
||||
x = ones({5});
|
||||
y = linspace(-2., 2., 5);
|
||||
z = outer(x, y);
|
||||
expected = array(
|
||||
{-2., -1., 0., 1., 2., -2., -1., 0., 1., 2., -2., -1., 0.,
|
||||
1., 2., -2., -1., 0., 1., 2., -2., -1., 0., 1., 2.},
|
||||
{5, 5});
|
||||
CHECK(array_equal(z, expected).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("inner") {
|
||||
CHECK_THROWS_AS(
|
||||
inner(reshape(arange(5.), {1, 5}), reshape(arange(6.), {2, 3})),
|
||||
std::invalid_argument);
|
||||
auto x = array({1., 2., 3.});
|
||||
auto y = array({0., 1., 0.});
|
||||
auto z = inner(x, y);
|
||||
CHECK_EQ(z.item<float>(), 2.f);
|
||||
|
||||
x = reshape(arange(24.), {2, 3, 4});
|
||||
y = arange(4.);
|
||||
z = inner(x, y);
|
||||
auto expected = array({14., 38., 62., 86., 110., 134.}, {2, 3});
|
||||
CHECK(array_equal(z, expected).item<bool>());
|
||||
|
||||
x = reshape(arange(2.), {1, 1, 2});
|
||||
y = reshape(arange(6.), {3, 2});
|
||||
z = inner(x, y);
|
||||
expected = array({1., 3., 5.}, {1, 1, 3});
|
||||
CHECK(array_equal(z, expected).item<bool>());
|
||||
|
||||
z = inner(eye(2), array(7.));
|
||||
expected = array({7., 0., 0., 7.}, {2, 2});
|
||||
CHECK(array_equal(z, expected).item<bool>());
|
||||
}
|
||||
|
Reference in New Issue
Block a user