use axes in tensordot (#525)

This commit is contained in:
Awni Hannun 2024-01-22 21:17:00 -08:00 committed by GitHub
parent f326dd8334
commit 98c37d3a22
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 10 additions and 10 deletions

View File

@ -3400,20 +3400,20 @@ void init_ops(py::module_& m) {
"a"_a, "a"_a,
"b"_a, "b"_a,
py::pos_only(), py::pos_only(),
"dims"_a = 2, "axes"_a = 2,
py::kw_only(), py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
tensordot(a: array, b: array, /, dims: Union[int, List[List[int]]] = 2, *, stream: Union[None, Stream, Device] = None) -> array tensordot(a: array, b: array, /, axes: Union[int, List[List[int]]] = 2, *, stream: Union[None, Stream, Device] = None) -> array
Compute the tensor dot product along the specified axes. Compute the tensor dot product along the specified axes.
Args: Args:
a (array): Input array a (array): Input array
b (array): Input array b (array): Input array
dims (int or list(list(int)), optional): The number of dimensions to axes (int or list(list(int)), optional): The number of dimensions to
sum over. If an integer is provided, then sum over the last sum over. If an integer is provided, then sum over the last
``dims`` dimensions of ``a`` and the first ``dims`` dimensions of ``axes`` dimensions of ``a`` and the first ``axes`` dimensions of
``b``. If a list of lists is provided, then sum over the ``b``. If a list of lists is provided, then sum over the
corresponding dimensions of ``a`` and ``b``. (default: 2) corresponding dimensions of ``a`` and ``b``. (default: 2)

View File

@ -1644,23 +1644,23 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertCmpNumpy( self.assertCmpNumpy(
[(3, 4, 5), (4, 3, 2)], [(3, 4, 5), (4, 3, 2)],
mx.tensordot, mx.tensordot,
lambda x, y, dims: np.tensordot(x, y, axes=dims), np.tensordot,
dtype=dtype, dtype=dtype,
dims=([1, 0], [0, 1]), axes=([1, 0], [0, 1]),
) )
self.assertCmpNumpy( self.assertCmpNumpy(
[(3, 4, 5), (4, 5, 6)], [(3, 4, 5), (4, 5, 6)],
mx.tensordot, mx.tensordot,
lambda x, y, dims: np.tensordot(x, y, axes=dims), np.tensordot,
dtype=dtype, dtype=dtype,
dims=2, axes=2,
) )
self.assertCmpNumpy( self.assertCmpNumpy(
[(3, 5, 4, 6), (6, 4, 5, 3)], [(3, 5, 4, 6), (6, 4, 5, 3)],
mx.tensordot, mx.tensordot,
lambda x, y, dims: np.tensordot(x, y, axes=dims), np.tensordot,
dtype=dtype, dtype=dtype,
dims=([2, 1, 3], [1, 2, 0]), axes=([2, 1, 3], [1, 2, 0]),
) )
def test_inner(self): def test_inner(self):