From 82a956c1d94872fd7030bc2ac6769c72661566e7 Mon Sep 17 00:00:00 2001 From: Alex Barron Date: Fri, 6 Dec 2024 10:26:54 -0800 Subject: [PATCH] fix test --- python/tests/test_fast_sdpa.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py index 73aa5b61d..528fc0274 100644 --- a/python/tests/test_fast_sdpa.py +++ b/python/tests/test_fast_sdpa.py @@ -128,10 +128,13 @@ class TestFastSDPA(mlx_tests.MLXTestCase): B = 1 H = 32 + dtypes = [mx.float32] + if self.is_apple_silcon: + dtypes.append(mx.float16) tests = product( [1, 7, 9, 32, 63, 67, 129, 2000], # sequence length [False, True], # gqa - [mx.float32, mx.float16], + dtypes, [4, 8], # bits ) for sequence_length, do_gqa, dtype, bits in tests: