mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Add Tensordot op (#344)
This commit is contained in:
@@ -1547,6 +1547,22 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
expected_3 = np.repeat(data_3, 2, axis=0)
|
||||
self.assertEqualArray(repeat_3, mx.array(expected_3))
|
||||
|
||||
def test_tensordot(self):
|
||||
x = mx.arange(60.0).reshape(3, 4, 5)
|
||||
y = mx.arange(24.0).reshape(4, 3, 2)
|
||||
z = mx.tensordot(x, y, dims=([1, 0], [0, 1]))
|
||||
self.assertEqualArray(z, mx.array(np.tensordot(x, y, axes=([1, 0], [0, 1]))))
|
||||
x = mx.random.normal((3, 4, 5))
|
||||
y = mx.random.normal((4, 5, 6))
|
||||
z = mx.tensordot(x, y, dims=2)
|
||||
self.assertEqualArray(z, mx.array(np.tensordot(x, y, axes=2)))
|
||||
x = mx.random.normal((3, 5, 4, 6))
|
||||
y = mx.random.normal((6, 4, 5, 3))
|
||||
z = mx.tensordot(x, y, dims=([2, 1, 3], [1, 2, 0]))
|
||||
self.assertEqualArray(
|
||||
z, mx.array(np.tensordot(x, y, axes=([2, 1, 3], [1, 2, 0])))
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user