From 3fb3385a1ae3e880fd28679f25211348ce8ff95a Mon Sep 17 00:00:00 2001 From: dc-dc-dc Date: Tue, 2 Jan 2024 13:16:18 -0500 Subject: [PATCH] python bindings --- python/src/ops.cpp | 22 ++++++++++++++++++++++ python/src/utils.h | 2 ++ python/tests/test_ops.py | 16 ++++++++++++++++ 3 files changed, 40 insertions(+) diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 1f60c6444..20dbe181a 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3194,4 +3194,26 @@ void init_ops(py::module_& m) { Returns: result (array): The dequantized version of ``w`` )pbdoc"); + m.def( + "tensordot", + [](const array& a, + const array& b, + const IntOrIntVec& dims, + StreamOrDevice s) { + if (std::holds_alternative(dims)) { + return tensordot(a, b, 2, s); + } else if (auto pv = std::get_if(&dims); pv) { + return tensordot(a, b, *pv, s); + } else { + return tensordot( + a, b, std::get>>(dims), s); + } + }, + "a"_a, + "b"_a, + py::pos_only(), + "dims"_a = 2, + py::kw_only(), + "stream"_a = none, + ""); } diff --git a/python/src/utils.h b/python/src/utils.h index 5ac878979..86da9f6db 100644 --- a/python/src/utils.h +++ b/python/src/utils.h @@ -15,6 +15,8 @@ namespace py = pybind11; using namespace mlx::core; using IntOrVec = std::variant>; +using IntOrIntVec = + std::variant>>; using ScalarOrArray = std:: variant, py::object>; static constexpr std::monostate none{}; diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 782249b56..65de09634 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1547,6 +1547,22 @@ class TestOps(mlx_tests.MLXTestCase): expected_3 = np.repeat(data_3, 2, axis=0) self.assertEqualArray(repeat_3, mx.array(expected_3)) + def test_tensordot(self): + x = mx.arange(60.0).reshape(3, 4, 5) + y = mx.arange(24.0).reshape(4, 3, 2) + z = mx.tensordot(x, y, dims=([1, 0], [0, 1])) + self.assertEqualArray(z, mx.array(np.tensordot(x, y, axes=([1, 0], [0, 1])))) + x = mx.random.normal((3, 4, 5)) + y = mx.random.normal((4, 5, 6)) + z = mx.tensordot(x, y, dims=2) + self.assertEqualArray(z, mx.array(np.tensordot(x, y, axes=2))) + x = mx.random.normal((3, 5, 4, 6)) + y = mx.random.normal((6, 4, 5, 3)) + z = mx.tensordot(x, y, dims=([2, 1, 3], [1, 2, 0])) + self.assertEqualArray( + z, mx.array(np.tensordot(x, y, axes=([2, 1, 3], [1, 2, 0]))) + ) + if __name__ == "__main__": unittest.main()