nice tensordot for mlx c (#782)

This commit is contained in:
Awni Hannun
2024-03-04 09:51:02 -08:00
committed by GitHub
parent 6a665ea6ed
commit 5121f028d9
4 changed files with 32 additions and 33 deletions

View File

@@ -2554,14 +2554,13 @@ TEST_CASE("tile") {
TEST_CASE("tensordot") {
auto x = reshape(arange(60.), {3, 4, 5});
auto y = reshape(arange(24.), {4, 3, 2});
auto z = tensordot(x, y, {{1, 0}, {0, 1}});
auto z = tensordot(x, y, {1, 0}, {0, 1});
auto expected = array(
{4400, 4730, 4532, 4874, 4664, 5018, 4796, 5162, 4928, 5306}, {5, 2});
CHECK(array_equal(z, expected).item<bool>());
x = reshape(arange(360.), {3, 4, 5, 6});
y = reshape(arange(360.), {6, 4, 5, 3});
CHECK_THROWS_AS(
tensordot(x, y, {{2, 1, 3}, {1, 2, 0}}), std::invalid_argument);
CHECK_THROWS_AS(tensordot(x, y, {2, 1, 3}, {1, 2, 0}), std::invalid_argument);
x = reshape(arange(60.), {3, 4, 5});
y = reshape(arange(120.), {4, 5, 6});
z = tensordot(x, y, 2);