mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
Add Tensordot op (#344)
This commit is contained in:
@@ -2277,4 +2277,41 @@ TEST_CASE("test repeat") {
|
||||
|
||||
// negative repeats
|
||||
CHECK_THROWS_AS(repeat(data_3, -3, 0), std::invalid_argument);
|
||||
}
|
||||
|
||||
TEST_CASE("tensordot") {
|
||||
auto x = reshape(arange(60.), {3, 4, 5});
|
||||
auto y = reshape(arange(24.), {4, 3, 2});
|
||||
auto z = tensordot(x, y, {{1, 0}, {0, 1}});
|
||||
auto expected = array(
|
||||
{4400, 4730, 4532, 4874, 4664, 5018, 4796, 5162, 4928, 5306}, {5, 2});
|
||||
CHECK(array_equal(z, expected).item<bool>());
|
||||
x = reshape(arange(360.), {3, 4, 5, 6});
|
||||
y = reshape(arange(360.), {6, 4, 5, 3});
|
||||
CHECK_THROWS_AS(
|
||||
tensordot(x, y, {{2, 1, 3}, {1, 2, 0}}), std::invalid_argument);
|
||||
x = reshape(arange(60.), {3, 4, 5});
|
||||
y = reshape(arange(120.), {4, 5, 6});
|
||||
z = tensordot(x, y, 2);
|
||||
expected = array(
|
||||
{14820.,
|
||||
15010.,
|
||||
15200.,
|
||||
15390.,
|
||||
15580.,
|
||||
15770.,
|
||||
37620.,
|
||||
38210.,
|
||||
38800.,
|
||||
39390.,
|
||||
39980.,
|
||||
40570.,
|
||||
60420.,
|
||||
61410.,
|
||||
62400.,
|
||||
63390.,
|
||||
64380.,
|
||||
65370.},
|
||||
{3, 6});
|
||||
CHECK(array_equal(z, expected).item<bool>());
|
||||
}
|
Reference in New Issue
Block a user