diff --git a/python/tests/mlx_tests.py b/python/tests/mlx_tests.py index 01ef407c32..9e8ed772eb 100644 --- a/python/tests/mlx_tests.py +++ b/python/tests/mlx_tests.py @@ -1,6 +1,7 @@ # Copyright © 2023 Apple Inc. import os +import platform import unittest from typing import Any, Callable, List, Tuple, Union @@ -9,6 +10,10 @@ import numpy as np class MLXTestCase(unittest.TestCase): + @property + def is_linux(self): + return platform.system() == "Linux" + def setUp(self): self.default = mx.default_device() device = os.getenv("DEVICE", None) diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 2e04477eb8..d291ca31ec 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1516,7 +1516,12 @@ class TestOps(mlx_tests.MLXTestCase): ) def test_tensordot(self): - for dtype in [mx.float16, mx.float32]: + # No fp16 matmuls on linux + if self.is_linux: + dtypes = [mx.float32] + else: + dtypes = [mx.float16, mx.float32] + for dtype in dtypes: with self.subTest(dtype=dtype): self.assertCmpNumpy( [(3, 4, 5), (4, 3, 2)],