diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 04b2a8e5b..e0e11f4c1 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -3190,42 +3190,41 @@ array dequantize( array tensordot( const array& a, const array& b, - const int dims /* = 2 */, + const int axis /* = 2 */, StreamOrDevice s /* = {} */ ) { - if (dims < 0) { + if (axis < 0) { throw std::invalid_argument( - "[tensordot] dims must be greater or equal to 0."); + "[tensordot] axis must be greater or equal to 0."); } - if (dims > std::min(a.ndim(), b.ndim())) { + if (axis > std::min(a.ndim(), b.ndim())) { throw std::invalid_argument( - "[tensordot] dims must be less than the number of dimensions of a and b."); + "[tensordot] axis must be less than the number of dimensions of a and b."); } std::vector adims; std::vector bdims; - for (int i = 0; i < dims; i++) { + for (int i = 0; i < axis; i++) { bdims.emplace_back(i); - adims.emplace_back(-dims + i); + adims.emplace_back(i - axis); } - return tensordot(a, b, {adims, bdims}, s); + return tensordot(a, b, {adims}, {bdims}, s); } array tensordot( const array& a, const array& b, - const std::pair, std::vector>& dims, - StreamOrDevice s /* = {} */ -) { - if (dims.first.size() != dims.second.size()) { - throw std::invalid_argument( - "[tensordot] dims[0] and dims[1] must have the same number of dimensions."); + const std::vector& axes_a, + const std::vector& axes_b, + StreamOrDevice s /* = {} */) { + if (axes_a.size() != axes_b.size()) { + throw std::invalid_argument("[tensordot] axes must have the same size."); } int csize = 1; auto x = a; auto y = b; - for (int i = 0; i < dims.first.size(); i++) { - if (x.shape(dims.first.at(i)) == y.shape(dims.second.at(i))) { - csize *= x.shape(dims.first.at(i)); + for (int i = 0; i < axes_a.size(); i++) { + if (x.shape(axes_a.at(i)) == y.shape(axes_b.at(i))) { + csize *= x.shape(axes_a.at(i)); } else { throw std::invalid_argument( "[tensordot] a and b must have the same shape on the contracted axes."); @@ -3234,11 +3233,11 @@ array tensordot( std::vector cdims1(x.ndim(), false); std::vector cdims2(y.ndim(), false); - for (const auto n : dims.first) { + for (const auto n : axes_a) { int n_ = (n < 0) ? n + x.ndim() : n; cdims1[n_] = true; } - for (const auto n : dims.second) { + for (const auto n : axes_b) { int n_ = (n < 0) ? n + y.ndim() : n; cdims2[n_] = true; } @@ -3255,10 +3254,10 @@ array tensordot( rshape.emplace_back(a.shape(i)); } } - for (const auto x : dims.first) { + for (const auto x : axes_a) { t1.emplace_back(x); } - for (const auto x : dims.second) { + for (const auto x : axes_b) { t2.emplace_back(x); } for (int i = 0; i < b.ndim(); i++) { @@ -3287,7 +3286,7 @@ array inner(const array& a, const array& b, StreamOrDevice s /* = {} */) { "[inner] a and b must have the same last dimension."); } - return tensordot(a, b, {{-1}, {-1}}, s); + return tensordot(a, b, {-1}, {-1}, s); } /** Compute D = beta * C + alpha * (A @ B) */ diff --git a/mlx/ops.h b/mlx/ops.h index b24c3971e..e5aa17c52 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1110,17 +1110,18 @@ array dequantize( int bits = 4, StreamOrDevice s = {}); -/** TensorDot returns a contraction of a and b over multiple dimensions. */ +/** Returns a contraction of a and b over multiple dimensions. */ array tensordot( const array& a, const array& b, - const int dims = 2, + const int axis = 2, StreamOrDevice s = {}); array tensordot( const array& a, const array& b, - const std::pair, std::vector>& dims, + const std::vector& axes_a, + const std::vector& axes_b, StreamOrDevice s = {}); /** Compute the outer product of two vectors. */ diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 2491939d5..8209169f2 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3555,17 +3555,17 @@ void init_ops(py::module_& m) { "tensordot", [](const array& a, const array& b, - const std::variant>>& dims, + const std::variant>>& axes, StreamOrDevice s) { - if (auto pv = std::get_if(&dims); pv) { + if (auto pv = std::get_if(&axes); pv) { return tensordot(a, b, *pv, s); } else { - auto x = std::get>>(dims); + auto& x = std::get>>(axes); if (x.size() != 2) { throw std::invalid_argument( - "[tensordot] dims must be a list of two lists."); + "[tensordot] axes must be a list of two lists."); } - return tensordot(a, b, {x[0], x[1]}, s); + return tensordot(a, b, x[0], x[1], s); } }, "a"_a, diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 4b05425e7..863b3fd72 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -2554,14 +2554,13 @@ TEST_CASE("tile") { TEST_CASE("tensordot") { auto x = reshape(arange(60.), {3, 4, 5}); auto y = reshape(arange(24.), {4, 3, 2}); - auto z = tensordot(x, y, {{1, 0}, {0, 1}}); + auto z = tensordot(x, y, {1, 0}, {0, 1}); auto expected = array( {4400, 4730, 4532, 4874, 4664, 5018, 4796, 5162, 4928, 5306}, {5, 2}); CHECK(array_equal(z, expected).item()); x = reshape(arange(360.), {3, 4, 5, 6}); y = reshape(arange(360.), {6, 4, 5, 3}); - CHECK_THROWS_AS( - tensordot(x, y, {{2, 1, 3}, {1, 2, 0}}), std::invalid_argument); + CHECK_THROWS_AS(tensordot(x, y, {2, 1, 3}, {1, 2, 0}), std::invalid_argument); x = reshape(arange(60.), {3, 4, 5}); y = reshape(arange(120.), {4, 5, 6}); z = tensordot(x, y, 2);