* fix ci

* check for linux for fp16
This commit is contained in:
Awni Hannun
2024-01-04 06:33:08 -08:00
committed by GitHub
parent d2467c320d
commit d752f8e142
2 changed files with 11 additions and 1 deletions

View File

@@ -1516,7 +1516,12 @@ class TestOps(mlx_tests.MLXTestCase):
)
def test_tensordot(self):
for dtype in [mx.float16, mx.float32]:
# No fp16 matmuls on linux
if self.is_linux:
dtypes = [mx.float32]
else:
dtypes = [mx.float16, mx.float32]
for dtype in dtypes:
with self.subTest(dtype=dtype):
self.assertCmpNumpy(
[(3, 4, 5), (4, 3, 2)],