* fix ci

* check for linux for fp16
This commit is contained in:
Awni Hannun 2024-01-04 06:33:08 -08:00 committed by GitHub
parent d2467c320d
commit d752f8e142
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 1 deletions

View File

@ -1,6 +1,7 @@
# Copyright © 2023 Apple Inc.
import os
import platform
import unittest
from typing import Any, Callable, List, Tuple, Union
@ -9,6 +10,10 @@ import numpy as np
class MLXTestCase(unittest.TestCase):
@property
def is_linux(self):
return platform.system() == "Linux"
def setUp(self):
self.default = mx.default_device()
device = os.getenv("DEVICE", None)

View File

@ -1516,7 +1516,12 @@ class TestOps(mlx_tests.MLXTestCase):
)
def test_tensordot(self):
for dtype in [mx.float16, mx.float32]:
# No fp16 matmuls on linux
if self.is_linux:
dtypes = [mx.float32]
else:
dtypes = [mx.float16, mx.float32]
for dtype in dtypes:
with self.subTest(dtype=dtype):
self.assertCmpNumpy(
[(3, 4, 5), (4, 3, 2)],