diff --git a/python/tests/mlx_tests.py b/python/tests/mlx_tests.py index dfbc835da..01ef407c3 100644 --- a/python/tests/mlx_tests.py +++ b/python/tests/mlx_tests.py @@ -2,7 +2,7 @@ import os import unittest -from typing import Any, Callable, List, Tuple +from typing import Any, Callable, List, Tuple, Union import mlx.core as mx import numpy as np @@ -21,7 +21,7 @@ class MLXTestCase(unittest.TestCase): def assertCmpNumpy( self, - shape: List[Tuple[int] | Any], + shape: List[Union[Tuple[int], Any]], mx_fn: Callable[..., mx.array], np_fn: Callable[..., np.array], atol=1e-2,