diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 879090e95..95a05436c 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3400,20 +3400,20 @@ void init_ops(py::module_& m) { "a"_a, "b"_a, py::pos_only(), - "dims"_a = 2, + "axes"_a = 2, py::kw_only(), "stream"_a = none, R"pbdoc( - tensordot(a: array, b: array, /, dims: Union[int, List[List[int]]] = 2, *, stream: Union[None, Stream, Device] = None) -> array + tensordot(a: array, b: array, /, axes: Union[int, List[List[int]]] = 2, *, stream: Union[None, Stream, Device] = None) -> array Compute the tensor dot product along the specified axes. Args: a (array): Input array b (array): Input array - dims (int or list(list(int)), optional): The number of dimensions to + axes (int or list(list(int)), optional): The number of dimensions to sum over. If an integer is provided, then sum over the last - ``dims`` dimensions of ``a`` and the first ``dims`` dimensions of + ``axes`` dimensions of ``a`` and the first ``axes`` dimensions of ``b``. If a list of lists is provided, then sum over the corresponding dimensions of ``a`` and ``b``. (default: 2) diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 542f1540e..f4e31df80 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1644,23 +1644,23 @@ class TestOps(mlx_tests.MLXTestCase): self.assertCmpNumpy( [(3, 4, 5), (4, 3, 2)], mx.tensordot, - lambda x, y, dims: np.tensordot(x, y, axes=dims), + np.tensordot, dtype=dtype, - dims=([1, 0], [0, 1]), + axes=([1, 0], [0, 1]), ) self.assertCmpNumpy( [(3, 4, 5), (4, 5, 6)], mx.tensordot, - lambda x, y, dims: np.tensordot(x, y, axes=dims), + np.tensordot, dtype=dtype, - dims=2, + axes=2, ) self.assertCmpNumpy( [(3, 5, 4, 6), (6, 4, 5, 3)], mx.tensordot, - lambda x, y, dims: np.tensordot(x, y, axes=dims), + np.tensordot, dtype=dtype, - dims=([2, 1, 3], [1, 2, 0]), + axes=([2, 1, 3], [1, 2, 0]), ) def test_inner(self):