From b42d13ec8443c299d32b4b254161413bfaa72acb Mon Sep 17 00:00:00 2001
From: Jagrit Digani
Date: Thu, 20 Mar 2025 14:25:38 -0700
Subject: [PATCH] Update attention tests to show diff, disable array masks
 (#1978)

---
 mlx/fast.cpp                   | 4 ++--
 python/tests/test_fast_sdpa.py | 9 ++++-----
 2 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/mlx/fast.cpp b/mlx/fast.cpp
index 342078a24..ed0d9fbe5 100644
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -748,8 +748,8 @@ array scaled_dot_product_attention(
       (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
 
   const bool sdpa_vector_supported_mask = (!has_mask || has_bool_mask);
-  const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
-      (query_sequence_length <= key_sequence_length && do_causal);
+  const bool sdpa_full_supported_mask =
+      !has_mask || (query_sequence_length <= key_sequence_length && do_causal);
 
   const bool supports_sdpa_full = query_sequence_length >= threshold &&
       sdpa_full_supported_mask && sdpa_full_supported_head_dim &&
diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py
index a269847de..4ea573564 100644
--- a/python/tests/test_fast_sdpa.py
+++ b/python/tests/test_fast_sdpa.py
@@ -72,8 +72,8 @@ def prepare_inputs(B, qL, kL, D, qH, kH, mask, transpose, dtype):
 
     scale = 1.0 / math.sqrt(D)
 
-    q_np = np.random.normal(0.0, 1.0, shape_q).astype(np_dtype)
-    k_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
+    q_np = np.random.normal(0.0, 0.5, shape_q).astype(np_dtype)
+    k_np = np.random.normal(0.0, 0.5, shape_kv).astype(np_dtype)
     v_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
 
     q_mx = mx.array(q_np)
@@ -524,9 +524,8 @@ class TestSDPA(mlx_tests.MLXTestCase):
                     list(out_ref.shape), list(out_fst.shape)
                 )
 
-                self.assertTrue(
-                    mx.allclose(out_fst, out_ref, atol=atol, rtol=atol)
-                )
+                diff = mx.abs(out_fst - out_ref) - atol * mx.abs(out_ref)
+                self.assertLessEqual(mx.max(diff).item(), atol)
 
 
 if __name__ == "__main__":
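
The test change above replaces a boolean `mx.allclose` assertion with an explicit
element-wise check: `|a - b| - atol * |b| <= atol` is equivalent to the
allclose-style bound `|a - b| <= atol + rtol * |b|` with `rtol` set equal to
`atol`, but on failure it reports the worst-case violation instead of a bare
`False`. A minimal sketch of the pattern, with made-up tensor shapes and a
made-up `atol` value (assuming `mlx` is installed):

    import mlx.core as mx

    atol = 1e-4                              # hypothetical tolerance
    out_ref = mx.random.normal((2, 8, 16))   # hypothetical reference output
    out_fst = out_ref + 1e-5                 # hypothetical fast-path output

    # Element-wise tolerance violation: entries above atol exceed the
    # allclose-style bound |a - b| <= atol + atol * |b|.
    diff = mx.abs(out_fst - out_ref) - atol * mx.abs(out_ref)
    max_violation = mx.max(diff).item()
    assert max_violation <= atol, f"worst-case violation: {max_violation}"

Keeping the check in this form means a CI failure log shows how far the fast
SDPA path actually strayed from the reference, which matches the "show diff"
intent in the patch subject.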