mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-02 17:28:12 +08:00
@@ -1516,7 +1516,12 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
)
|
||||
|
||||
def test_tensordot(self):
|
||||
for dtype in [mx.float16, mx.float32]:
|
||||
# No fp16 matmuls on linux
|
||||
if self.is_linux:
|
||||
dtypes = [mx.float32]
|
||||
else:
|
||||
dtypes = [mx.float16, mx.float32]
|
||||
for dtype in dtypes:
|
||||
with self.subTest(dtype=dtype):
|
||||
self.assertCmpNumpy(
|
||||
[(3, 4, 5), (4, 3, 2)],
|
||||
|
||||
Reference in New Issue
Block a user