diff --git a/mlx/fast.cpp b/mlx/fast.cpp
index 3848bd58c..bd58715e9 100644
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -792,7 +792,8 @@ array quantized_scaled_dot_product_attention(
   int query_head_dim = queries.shape(-1);
   int L = queries.shape(2);
 
-  if (L > 1 && query_head_dim != 64 && query_head_dim != 128) {
+  bool compatible_head_dim = query_head_dim == 64 || query_head_dim == 128;
+  if (L > 1 || !compatible_head_dim || stream.device != Device::gpu) {
     if (needs_mask) {
       return fallback(
           {queries,
diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py
index ad3262e7c..73aa5b61d 100644
--- a/python/tests/test_fast_sdpa.py
+++ b/python/tests/test_fast_sdpa.py
@@ -168,7 +168,6 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
         if dtype == mx.float16:
             rtol = 1e-2
 
-        # np.testing.assert_allclose(o_q, reference, rtol=rtol, atol=atol)
         self.assertTrue(mx.allclose(o_q, reference, rtol=rtol, atol=atol))
         self.assertTrue(mx.allclose(o, reference, rtol=rtol, atol=atol))
 
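Not part of the patch, only an illustration: a minimal Python sketch of the kind of reference check the test file performs, run with a head dim outside {64, 128} and L > 1, i.e. a shape the broadened condition now sends to the fallback path. It uses the public mx.fast.scaled_dot_product_attention API rather than the quantized C++ entry point; the shapes and tolerances are assumptions for the example, not values from the test above.

import mlx.core as mx

# D is deliberately not 64 or 128 and L > 1, the kind of shape the patched
# condition routes to the fallback; this sketch only checks that fast SDPA
# still matches a plain softmax(Q K^T * scale) V reference for such a shape.
B, H, L, D = 1, 8, 4, 80
q = mx.random.normal((B, H, L, D))
k = mx.random.normal((B, H, L, D))
v = mx.random.normal((B, H, L, D))
scale = 1.0 / (D ** 0.5)

o = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)

# Unfused reference implementation, mirroring the check in test_fast_sdpa.py.
scores = (q * scale) @ k.transpose(0, 1, 3, 2)
reference = mx.softmax(scores, axis=-1) @ v

assert mx.allclose(o, reference, rtol=1e-4, atol=1e-4)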