use axes in tensordot (#525)

This commit is contained in:
Awni Hannun
2024-01-22 21:17:00 -08:00
committed by GitHub
parent f326dd8334
commit 98c37d3a22
2 changed files with 10 additions and 10 deletions

View File

@@ -1644,23 +1644,23 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertCmpNumpy(
[(3, 4, 5), (4, 3, 2)],
mx.tensordot,
lambda x, y, dims: np.tensordot(x, y, axes=dims),
np.tensordot,
dtype=dtype,
dims=([1, 0], [0, 1]),
axes=([1, 0], [0, 1]),
)
self.assertCmpNumpy(
[(3, 4, 5), (4, 5, 6)],
mx.tensordot,
lambda x, y, dims: np.tensordot(x, y, axes=dims),
np.tensordot,
dtype=dtype,
dims=2,
axes=2,
)
self.assertCmpNumpy(
[(3, 5, 4, 6), (6, 4, 5, 3)],
mx.tensordot,
lambda x, y, dims: np.tensordot(x, y, axes=dims),
np.tensordot,
dtype=dtype,
dims=([2, 1, 3], [1, 2, 0]),
axes=([2, 1, 3], [1, 2, 0]),
)
def test_inner(self):