diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 8dd4221b3..094756ac5 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -172,6 +172,7 @@ void sdpa_vector( std::string hash_name = kname; hash_name += has_mask ? "_mask" : "_nomask"; hash_name += query_transposed ? "_qt" : "_qnt"; + hash_name += do_causal ? "_c" : "_nc"; // Get the kernel auto& compute_encoder = d.get_command_encoder(s.index); @@ -268,6 +269,7 @@ void sdpa_vector_2pass( std::string hash_name = kname; hash_name += has_mask ? "_mask" : "_nomask"; hash_name += query_transposed ? "_qt" : "_qnt"; + hash_name += do_causal ? "_c" : "_nc"; // Get the kernel auto& compute_encoder = d.get_command_encoder(s.index); diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py index b767fbc8f..e7b7e5ac3 100644 --- a/python/tests/test_fast_sdpa.py +++ b/python/tests/test_fast_sdpa.py @@ -347,6 +347,7 @@ class TestFastSDPA(mlx_tests.MLXTestCase): ) masks = [ + None, mx.array(True), mx.array([True] * (L - 10) + [False] * 10), mx.random.uniform(shape=(Nq, 1, L)) > 0.2, @@ -392,7 +393,7 @@ class TestFastSDPA(mlx_tests.MLXTestCase): def test_fast_sdpa_few_query(self): D = 64 L = 43 - Lq = 4 + Lq = 8 Nq = 8 Nkv = 1 scale = 1.0 @@ -403,6 +404,7 @@ class TestFastSDPA(mlx_tests.MLXTestCase): v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D)) masks = [ + None, mx.array(True), mx.array([True] * (L - 10) + [False] * 10), mx.random.uniform(shape=(Nq, 1, L)) > 0.2, @@ -428,6 +430,7 @@ class TestFastSDPA(mlx_tests.MLXTestCase): v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D)) masks = [ + None, mx.array(True), mx.array([True] * (L - 10) + [False] * 10), mx.random.uniform(shape=(Nq, 1, L)) > 0.2,