From 859ae15a5401d13c2a715aa7b6e1e6f062f00bb8 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 4 Mar 2024 23:02:27 -0800 Subject: [PATCH] Fix test (#785) --- python/tests/test_fast_sdpa.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py index 8c8a599f4..4be45a552 100644 --- a/python/tests/test_fast_sdpa.py +++ b/python/tests/test_fast_sdpa.py @@ -32,12 +32,8 @@ def mlx_primitives_sdpa_with_gqa(q, k, v, scale): return mlx_primitives_sdpa(q, k, v, scale) -class TestFastInferenceSDPA(mlx_tests.MLXTestCase): - @property - def dtypes(self): - return ["float32", "float16"] if mx.metal.is_available() else ["float32"] - - def test_fast_inference_sdpa(self): +class TestFastSDPA(mlx_tests.MLXTestCase): + def test_fast_sdpa(self): # Not yet supported: # * K pre-transposed in kernel, V pre-transposed in kernel @@ -65,9 +61,13 @@ class TestFastInferenceSDPA(mlx_tests.MLXTestCase): B = 1 H = 32 + dtypes = [np.float32] + if not self.is_linux: + dtypes.append(np.half) + for SEQUENCE_LENGTH in [1, 7, 9, 32, 63, 67, 129, 400, 2000]: for DO_GQA in [0, 1]: - for DTYPE in [np.float32, np.half]: + for DTYPE in dtypes: n_kv_heads = 8 if DO_GQA else 32 q_npy = np.random.normal(0.0, 1.0, (B, H, R, Dk)).astype(DTYPE) k_npy = np.random.normal(