diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md
index 398f136b2..15648e6e0 100644
--- a/ACKNOWLEDGMENTS.md
+++ b/ACKNOWLEDGMENTS.md
@@ -10,7 +10,7 @@ MLX was developed with contributions from the following individuals:
 - Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops.
 - Juarez Bochi: Fixed bug in cross attention.
 - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
-- Diogo Da Cruz: Added tri, tril, triu and safetensor support
+- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot` and safetensor support
 
 # Third-Party Software
 
diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst
index 6c31b54ec..4e399524e 100644
--- a/docs/src/python/ops.rst
+++ b/docs/src/python/ops.rst
@@ -104,6 +104,7 @@ Operations
    take_along_axis
    tan
    tanh
+   tensordot
    transpose
    tri
    tril
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 014707b38..e1f593aba 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -2793,4 +2793,100 @@ array dequantize(
   return w_full;
 }
 
+array tensordot(
+    const array& a,
+    const array& b,
+    const int dims /* = 2 */,
+    StreamOrDevice s /* = {} */
+) {
+  if (dims < 0) {
+    throw std::invalid_argument(
+        "[tensordot] dims must be greater than or equal to 0.");
+  }
+  if (dims > std::min(a.ndim(), b.ndim())) {
+    throw std::invalid_argument(
+        "[tensordot] dims must not exceed the smaller of the number of dimensions of a and b.");
+  }
+  // Contract the last `dims` axes of a with the first `dims` axes of b.
+  std::vector<int> adims;
+  std::vector<int> bdims;
+  for (int i = 0; i < dims; i++) {
+    bdims.emplace_back(i);
+    adims.emplace_back(-dims + i);
+  }
+  return tensordot(a, b, {adims, bdims}, s);
+}
+
+array tensordot(
+    const array& a,
+    const array& b,
+    const std::pair<std::vector<int>, std::vector<int>>& dims,
+    StreamOrDevice s /* = {} */
+) {
+  if (dims.first.size() != dims.second.size()) {
+    throw std::invalid_argument(
+        "[tensordot] dims[0] and dims[1] must have the same number of dimensions.");
+  }
+  if (a.dtype() != b.dtype()) {
+    throw std::invalid_argument(
+        "[tensordot] a and b must have the same dtype.");
+  }
+  int csize = 1;
+  auto x = a;
+  auto y = b;
+  // Check that the contracted axes match and accumulate their total size.
+  for (int i = 0; i < dims.first.size(); i++) {
+    if (x.shape(dims.first.at(i)) == y.shape(dims.second.at(i))) {
+      csize *= x.shape(dims.first.at(i));
+    } else {
+      throw std::invalid_argument(
+          "[tensordot] a and b must have the same shape on the contracted axes.");
+    }
+  }
+
+  // Flag the contracted axes of each operand, normalizing negative indices.
+  std::vector<bool> cdims1(x.ndim(), false);
+  std::vector<bool> cdims2(y.ndim(), false);
+  for (const auto n : dims.first) {
+    int n_ = (n < 0) ? n + x.ndim() : n;
+    cdims1[n_] = true;
+  }
+  for (const auto n : dims.second) {
+    int n_ = (n < 0) ? n + y.ndim() : n;
+    cdims2[n_] = true;
+  }
+
+  // Build permutations that put the contracted axes of a last and those of
+  // b first; the remaining (free) axes determine the output shape.
+  std::vector<int> t1;
+  std::vector<int> t2;
+  std::vector<int> rshape;
+  int size1 = 1;
+  int size2 = 1;
+  for (int i = 0; i < a.ndim(); i++) {
+    if (!cdims1[i]) {
+      t1.emplace_back(i);
+      size1 *= a.shape(i);
+      rshape.emplace_back(a.shape(i));
+    }
+  }
+  for (const auto x : dims.first) {
+    t1.emplace_back(x);
+  }
+  for (const auto x : dims.second) {
+    t2.emplace_back(x);
+  }
+  for (int i = 0; i < b.ndim(); i++) {
+    if (!cdims2[i]) {
+      t2.emplace_back(i);
+      size2 *= b.shape(i);
+      rshape.emplace_back(b.shape(i));
+    }
+  }
+  // Collapse each operand to a matrix and do the contraction as one matmul.
+  x = reshape(transpose(x, t1, s), {size1, csize}, s);
+  y = reshape(transpose(y, t2, s), {csize, size2}, s);
+  return reshape(matmul(x, y, s), rshape, s);
+}
+
 } // namespace mlx::core
diff --git a/mlx/ops.h b/mlx/ops.h
index c888c80cd..6516ad008 100644
--- a/mlx/ops.h
+++ b/mlx/ops.h
@@ -1061,6 +1061,19 @@ array dequantize(
     int bits = 4,
     StreamOrDevice s = {});
 
+/** Compute the tensor dot product of a and b along the specified dimensions. */
+array tensordot(
+    const array& a,
+    const array& b,
+    const int dims = 2,
+    StreamOrDevice s = {});
+
+array tensordot(
+    const array& a,
+    const array& b,
+    const std::pair<std::vector<int>, std::vector<int>>& dims,
+    StreamOrDevice s = {});
+
 /** Load array map from .safetensors file format */
 std::unordered_map<std::string, array> load_safetensors(
     std::shared_ptr<io::Reader> in_stream,
diff --git a/python/src/ops.cpp b/python/src/ops.cpp
index 1f60c6444..c152eeb97 100644
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -3194,4 +3194,44 @@ void init_ops(py::module_& m) {
       Returns:
         result (array): The dequantized version of ``w``
       )pbdoc");
+  m.def(
+      "tensordot",
+      [](const array& a,
+         const array& b,
+         const std::variant<int, std::vector<std::vector<int>>>& dims,
+         StreamOrDevice s) {
+        if (auto pv = std::get_if<int>(&dims); pv) {
+          return tensordot(a, b, *pv, s);
+        } else {
+          auto x = std::get<std::vector<std::vector<int>>>(dims);
+          if (x.size() != 2) {
+            throw std::invalid_argument(
+                "[tensordot] dims must be a list of two lists.");
+          }
+          return tensordot(a, b, {x[0], x[1]}, s);
+        }
+      },
+      "a"_a,
+      "b"_a,
+      py::pos_only(),
+      "dims"_a = 2,
+      py::kw_only(),
+      "stream"_a = none,
+      R"pbdoc(
+        tensordot(a: array, b: array, /, dims: Union[int, List[List[int]]] = 2, *, stream: Union[None, Stream, Device] = None) -> array
+
+        Compute the tensor dot product along the specified axes.
+
+        Args:
+          a (array): Input array
+          b (array): Input array
+          dims (int or list(list(int)), optional): The number of dimensions to
+            sum over. If an integer is provided, then sum over the last
+            ``dims`` dimensions of ``a`` and the first ``dims`` dimensions of
+            ``b``. If a list of lists is provided, then sum over the
+            corresponding dimensions of ``a`` and ``b``. (default: 2)
+
+        Returns:
+          result (array): The tensor dot product.
+      )pbdoc");
 }
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index 782249b56..65de09634 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -1547,6 +1547,22 @@ class TestOps(mlx_tests.MLXTestCase):
         expected_3 = np.repeat(data_3, 2, axis=0)
         self.assertEqualArray(repeat_3, mx.array(expected_3))
 
+    def test_tensordot(self):
+        x = mx.arange(60.0).reshape(3, 4, 5)
+        y = mx.arange(24.0).reshape(4, 3, 2)
+        z = mx.tensordot(x, y, dims=([1, 0], [0, 1]))
+        self.assertEqualArray(z, mx.array(np.tensordot(x, y, axes=([1, 0], [0, 1]))))
+        x = mx.random.normal((3, 4, 5))
+        y = mx.random.normal((4, 5, 6))
+        z = mx.tensordot(x, y, dims=2)
+        self.assertEqualArray(z, mx.array(np.tensordot(x, y, axes=2)))
+        x = mx.random.normal((3, 5, 4, 6))
+        y = mx.random.normal((6, 4, 5, 3))
+        z = mx.tensordot(x, y, dims=([2, 1, 3], [1, 2, 0]))
+        self.assertEqualArray(
+            z, mx.array(np.tensordot(x, y, axes=([2, 1, 3], [1, 2, 0])))
+        )
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp
index f6443bc7e..ad767f525 100644
--- a/tests/ops_tests.cpp
+++ b/tests/ops_tests.cpp
@@ -2277,4 +2277,41 @@ TEST_CASE("test repeat") {
 
   // negative repeats
   CHECK_THROWS_AS(repeat(data_3, -3, 0), std::invalid_argument);
-}
\ No newline at end of file
+}
+
+TEST_CASE("tensordot") {
+  auto x = reshape(arange(60.), {3, 4, 5});
+  auto y = reshape(arange(24.), {4, 3, 2});
+  auto z = tensordot(x, y, {{1, 0}, {0, 1}});
+  auto expected = array(
+      {4400, 4730, 4532, 4874, 4664, 5018, 4796, 5162, 4928, 5306}, {5, 2});
+  CHECK(array_equal(z, expected).item<bool>());
+  x = reshape(arange(360.), {3, 4, 5, 6});
+  y = reshape(arange(360.), {6, 4, 5, 3});
+  CHECK_THROWS_AS(
+      tensordot(x, y, {{2, 1, 3}, {1, 2, 0}}), std::invalid_argument);
+  x = reshape(arange(60.), {3, 4, 5});
+  y = reshape(arange(120.), {4, 5, 6});
+  z = tensordot(x, y, 2);
+  expected = array(
+      {14820.,
+       15010.,
+       15200.,
+       15390.,
+       15580.,
+       15770.,
+       37620.,
+       38210.,
+       38800.,
+       39390.,
+       39980.,
+       40570.,
+       60420.,
+       61410.,
+       62400.,
+       63390.,
+       64380.,
+       65370.},
+      {3, 6});
+  CHECK(array_equal(z, expected).item<bool>());
+}
\ No newline at end of file