diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py index 73aa5b61d..528fc0274 100644 --- a/python/tests/test_fast_sdpa.py +++ b/python/tests/test_fast_sdpa.py @@ -128,10 +128,13 @@ class TestFastSDPA(mlx_tests.MLXTestCase): B = 1 H = 32 + dtypes = [mx.float32] + if self.is_apple_silcon: + dtypes.append(mx.float16) tests = product( [1, 7, 9, 32, 63, 67, 129, 2000], # sequence length [False, True], # gqa - [mx.float32, mx.float16], + dtypes, [4, 8], # bits ) for sequence_length, do_gqa, dtype, bits in tests: