mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
nice tensordot for mlx c (#782)
This commit is contained in:
@@ -2554,14 +2554,13 @@ TEST_CASE("tile") {
|
||||
TEST_CASE("tensordot") {
|
||||
auto x = reshape(arange(60.), {3, 4, 5});
|
||||
auto y = reshape(arange(24.), {4, 3, 2});
|
||||
auto z = tensordot(x, y, {{1, 0}, {0, 1}});
|
||||
auto z = tensordot(x, y, {1, 0}, {0, 1});
|
||||
auto expected = array(
|
||||
{4400, 4730, 4532, 4874, 4664, 5018, 4796, 5162, 4928, 5306}, {5, 2});
|
||||
CHECK(array_equal(z, expected).item<bool>());
|
||||
x = reshape(arange(360.), {3, 4, 5, 6});
|
||||
y = reshape(arange(360.), {6, 4, 5, 3});
|
||||
CHECK_THROWS_AS(
|
||||
tensordot(x, y, {{2, 1, 3}, {1, 2, 0}}), std::invalid_argument);
|
||||
CHECK_THROWS_AS(tensordot(x, y, {2, 1, 3}, {1, 2, 0}), std::invalid_argument);
|
||||
x = reshape(arange(60.), {3, 4, 5});
|
||||
y = reshape(arange(120.), {4, 5, 6});
|
||||
z = tensordot(x, y, 2);
|
||||
|
Reference in New Issue
Block a user