diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 014707b38..b8b0fe1c9 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -2793,4 +2793,141 @@ array dequantize(
   return w_full;
 }
 
+// Contract the last `dims` axes of `a` with the first `dims` axes of `b`.
+// Throws std::invalid_argument if `dims` is negative or larger than the rank
+// of either input; everything else is checked by the axis-list overload.
+array tensordot(
+    const array& a,
+    const array& b,
+    const int dims /* = 2 */,
+    StreamOrDevice s /* = {} */) {
+  if (dims < 0) {
+    throw std::invalid_argument(
+        "[tensordot] dims must be greater or equal to 0.");
+  }
+  // dims is non-negative here, so the cast cannot wrap.
+  if (static_cast<size_t>(dims) > std::min(a.ndim(), b.ndim())) {
+    throw std::invalid_argument(
+        "[tensordot] dims must be less than or equal to the number of dimensions of a and b.");
+  }
+  std::vector<int> adims;
+  std::vector<int> bdims;
+  adims.reserve(dims);
+  bdims.reserve(dims);
+  for (int i = 0; i < dims; i++) {
+    adims.emplace_back(i - dims); // trailing axes of a, counted from the end
+    bdims.emplace_back(i); // leading axes of b
+  }
+  return tensordot(a, b, {adims, bdims}, s);
+}
+
+// Contract the axes of `a` listed in dims[0] with the axes of `b` listed in
+// dims[1], pairwise and in order. Negative axes count from the end. The result
+// holds the free axes of `a` followed by the free axes of `b`.
+//
+// Throws std::invalid_argument when dims is not two equally long axis lists,
+// when the dtypes differ, when an axis is out of range or repeated, or when a
+// contracted pair has different sizes and neither size is 1.
+array tensordot(
+    const array& a,
+    const array& b,
+    const std::vector<std::vector<int>>& dims,
+    StreamOrDevice s /* = {} */) {
+  if (dims.size() != 2) {
+    throw std::invalid_argument(
+        "[tensordot] dims must be a vector of two vectors.");
+  }
+  if (dims[0].size() != dims[1].size()) {
+    throw std::invalid_argument(
+        "[tensordot] dims[0] and dims[1] must have the same number of dimensions.");
+  }
+  if (a.dtype() != b.dtype()) {
+    throw std::invalid_argument(
+        "[tensordot] a and b must have the same dtype.");
+  }
+
+  // Size of the shared inner dimension of the final matmul.
+  int csize = 1;
+  auto x = a;
+  auto y = b;
+  for (size_t i = 0; i < dims[0].size(); i++) {
+    const int axis_a = dims[0][i];
+    const int axis_b = dims[1][i];
+    const int size_a = x.shape(axis_a);
+    const int size_b = y.shape(axis_b);
+    if (size_b == 1) {
+      // b broadcasts along this axis, so a can be summed over it up front.
+      x = sum(x, axis_a, true, s);
+    } else if (size_a == 1) {
+      y = sum(y, axis_b, true, s);
+    } else if (size_a != size_b) {
+      // Sizes must agree unless one side has size 1.
+      throw std::invalid_argument(
+          "[tensordot] a and b must have the same shape on the contracted axes.");
+    } else {
+      csize *= size_a;
+    }
+  }
+
+  // Map each axis into [0, ndim) and flag it in `contracted`, rejecting
+  // out-of-range and repeated entries.
+  auto resolve_axes = [](const std::vector<int>& axes,
+                         std::vector<bool>& contracted) {
+    const int ndim = static_cast<int>(contracted.size());
+    std::vector<int> resolved;
+    resolved.reserve(axes.size());
+    for (const int axis : axes) {
+      const int pos = (axis < 0) ? axis + ndim : axis;
+      if (pos < 0 || pos >= ndim) {
+        throw std::invalid_argument(
+            "[tensordot] contracted axis is out of bounds.");
+      }
+      if (contracted[pos]) {
+        throw std::invalid_argument(
+            "[tensordot] contracted axes must not repeat.");
+      }
+      contracted[pos] = true;
+      resolved.emplace_back(pos);
+    }
+    return resolved;
+  };
+  std::vector<bool> contracted_a(a.ndim(), false);
+  std::vector<bool> contracted_b(b.ndim(), false);
+  const auto axes_a = resolve_axes(dims[0], contracted_a);
+  const auto axes_b = resolve_axes(dims[1], contracted_b);
+
+  // Permute a to (free..., contracted...) and b to (contracted..., free...).
+  const int ndim_a = static_cast<int>(a.ndim());
+  const int ndim_b = static_cast<int>(b.ndim());
+  std::vector<int> t1;
+  t1.reserve(ndim_a);
+  std::vector<int> t2;
+  t2.reserve(ndim_b);
+  std::vector<int> rshape;
+  rshape.reserve(ndim_a + ndim_b - 2 * axes_a.size());
+  int size1 = 1;
+  int size2 = 1;
+  for (int i = 0; i < ndim_a; i++) {
+    if (!contracted_a[i]) {
+      t1.emplace_back(i);
+      size1 *= a.shape(i);
+      rshape.emplace_back(a.shape(i));
+    }
+  }
+  t1.insert(t1.end(), axes_a.begin(), axes_a.end());
+  t2.insert(t2.end(), axes_b.begin(), axes_b.end());
+  for (int i = 0; i < ndim_b; i++) {
+    if (!contracted_b[i]) {
+      t2.emplace_back(i);
+      size2 *= b.shape(i);
+      rshape.emplace_back(b.shape(i));
+    }
+  }
+
+  // Collapse both operands to 2-D, multiply, then restore the free axes.
+  x = reshape(transpose(x, t1, s), {size1, csize}, s);
+  y = reshape(transpose(y, t2, s), {csize, size2}, s);
+  return reshape(matmul(x, y, s), rshape, s);
+}
+
 } // namespace mlx::core
diff --git a/mlx/ops.h b/mlx/ops.h
index c888c80cd..c556e013a 100644
--- a/mlx/ops.h
+++ b/mlx/ops.h
@@ -1061,6 +1061,26 @@ array dequantize(
     int bits = 4,
     StreamOrDevice s = {});
 
+/**
+ * Compute the tensor dot product of a and b, contracting the last `dims`
+ * axes of a with the first `dims` axes of b.
+ */
+array tensordot(
+    const array& a,
+    const array& b,
+    const int dims = 2,
+    StreamOrDevice s = {});
+
+/**
+ * Compute the tensor dot product of a and b, contracting the axes of a listed
+ * in dims[0] with the axes of b listed in dims[1], pairwise and in order.
+ */
+array tensordot(
+    const array& a,
+    const array& b,
+    const std::vector<std::vector<int>>& dims,
+    StreamOrDevice s = {});
+
 /** Load array map from .safetensors file format */
 std::unordered_map<std::string, array> load_safetensors(
     std::shared_ptr<io::Reader> in_stream,
diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp
index f6443bc7e..f140e5246 100644
--- a/tests/ops_tests.cpp
+++ b/tests/ops_tests.cpp
@@ -2277,4 +2277,59 @@ TEST_CASE("test repeat") {
 
   // negative repeats
   CHECK_THROWS_AS(repeat(data_3, -3, 0), std::invalid_argument);
-}
\ No newline at end of file
+}
+
+TEST_CASE("tensordot") {
+  auto x = reshape(arange(60.), {3, 4, 5});
+  auto y = reshape(arange(24.), {4, 3, 2});
+  auto z = tensordot(x, y, {{1, 0}, {0, 1}});
+  auto expected = array(
+      {4400, 4730, 4532, 4874, 4664, 5018, 4796, 5162, 4928, 5306}, {5, 2});
+  CHECK(array_equal(z, expected).item<bool>());
+  x = reshape(arange(360.), {3, 4, 5, 6});
+  y = reshape(arange(360.), {6, 4, 5, 3});
+  z = tensordot(x, y, {{2, 1, 3}, {1, 2, 0}});
+  expected = array(
+      {1326270,
+       1333410,
+       1340550,
+       3896670,
+       3918210,
+       3939750,
+       6467070,
+       6503010,
+       6538950},
+      {3, 3});
+  CHECK(array_equal(z, expected).item<bool>());
+  x = reshape(arange(60.), {3, 4, 5});
+  y = reshape(arange(120.), {4, 5, 6});
+  z = tensordot(x, y, 2);
+  expected = array(
+      {14820.,
+       15010.,
+       15200.,
+       15390.,
+       15580.,
+       15770.,
+       37620.,
+       38210.,
+       38800.,
+       39390.,
+       39980.,
+       40570.,
+       60420.,
+       61410.,
+       62400.,
+       63390.,
+       64380.,
+       65370.},
+      {3, 6});
+  CHECK(array_equal(z, expected).item<bool>());
+
+  // contracted axes with different sizes (3 vs 4, 4 vs 5) must be rejected
+  CHECK_THROWS_AS(
+      tensordot(x, y, {{0, 1}, {0, 1}}), std::invalid_argument);
+  // a repeated contraction axis must be rejected
+  CHECK_THROWS_AS(
+      tensordot(x, y, {{1, 1}, {0, 0}}), std::invalid_argument);
+}