From 34108e780ced2fa6c0978a36afc5a558927c71b4 Mon Sep 17 00:00:00 2001 From: dc-dc-dc Date: Tue, 2 Jan 2024 18:11:54 -0500 Subject: [PATCH] implemented changes --- mlx/ops.cpp | 14 ++++---------- tests/ops_tests.cpp | 15 ++------------- 2 files changed, 6 insertions(+), 23 deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index b8b0fe1c9..7985c1bd9 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -2838,14 +2838,11 @@ array tensordot( auto x = a; auto y = b; for (int i = 0; i < dims[0].size(); i++) { - size_t xs = x.shape(dims[0].at(i)); - size_t ys = y.shape(dims[1].at(i)); - if (ys == 1) { - x = sum(x, dims[0].at(i), true, s); - } else if (xs == 1) { - y = sum(y, dims[1].at(i), true, s); + if (x.shape(dims[0].at(i)) == y.shape(dims[1].at(i))) { + csize *= x.shape(dims[0].at(i)); } else { - csize *= xs; + throw std::invalid_argument( + "[tensordot] a and b must have the same shape on the contracted axes."); } } @@ -2861,11 +2858,8 @@ array tensordot( } std::vector t1; - t1.reserve(a.ndim()); std::vector t2; - t2.reserve(b.ndim()); std::vector rshape; - rshape.reserve(a.ndim() + b.ndim() * dims[0].size()); int size1 = 1; int size2 = 1; for (int i = 0; i < a.ndim(); i++) { diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index f140e5246..ad767f525 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -2288,19 +2288,8 @@ TEST_CASE("tensordot") { CHECK(array_equal(z, expected).item()); x = reshape(arange(360.), {3, 4, 5, 6}); y = reshape(arange(360.), {6, 4, 5, 3}); - z = tensordot(x, y, {{2, 1, 3}, {1, 2, 0}}); - expected = array( - {1326270, - 1333410, - 1340550, - 3896670, - 3918210, - 3939750, - 6467070, - 6503010, - 6538950}, - {3, 3}); - CHECK(array_equal(z, expected).item()); + CHECK_THROWS_AS( + tensordot(x, y, {{2, 1, 3}, {1, 2, 0}}), std::invalid_argument); x = reshape(arange(60.), {3, 4, 5}); y = reshape(arange(120.), {4, 5, 6}); z = tensordot(x, y, 2);