mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-06 16:51:24 +08:00
use axes in tensordot (#525)
This commit is contained in:
@@ -1644,23 +1644,23 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertCmpNumpy(
|
||||
[(3, 4, 5), (4, 3, 2)],
|
||||
mx.tensordot,
|
||||
lambda x, y, dims: np.tensordot(x, y, axes=dims),
|
||||
np.tensordot,
|
||||
dtype=dtype,
|
||||
dims=([1, 0], [0, 1]),
|
||||
axes=([1, 0], [0, 1]),
|
||||
)
|
||||
self.assertCmpNumpy(
|
||||
[(3, 4, 5), (4, 5, 6)],
|
||||
mx.tensordot,
|
||||
lambda x, y, dims: np.tensordot(x, y, axes=dims),
|
||||
np.tensordot,
|
||||
dtype=dtype,
|
||||
dims=2,
|
||||
axes=2,
|
||||
)
|
||||
self.assertCmpNumpy(
|
||||
[(3, 5, 4, 6), (6, 4, 5, 3)],
|
||||
mx.tensordot,
|
||||
lambda x, y, dims: np.tensordot(x, y, axes=dims),
|
||||
np.tensordot,
|
||||
dtype=dtype,
|
||||
dims=([2, 1, 3], [1, 2, 0]),
|
||||
axes=([2, 1, 3], [1, 2, 0]),
|
||||
)
|
||||
|
||||
def test_inner(self):
|
||||
|
Reference in New Issue
Block a user