diff --git a/mlx/fast.cpp b/mlx/fast.cpp
index ac3cfe042..bcc5ccbd3 100644
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -750,6 +750,8 @@ array scaled_dot_product_attention(
       msg << "[scaled_dot_product_attention] Mask type must promote to output type. "
           << final_type << ".";
       throw std::invalid_argument(msg.str());
+    } else if (!has_bool_mask) {
+      mask_arr = astype(mask_arr, final_type, stream);
     }
     // Broadcast mask
     auto mask_shape = queries.shape();
diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py
index 78e03159f..612c284aa 100644
--- a/python/tests/test_fast_sdpa.py
+++ b/python/tests/test_fast_sdpa.py
@@ -543,6 +543,22 @@ class TestSDPA(mlx_tests.MLXTestCase):
         out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
         self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
 
+    def test_sdpa_promote_mask(self):
+        mask = mx.array(2.0, mx.bfloat16)
+        D = 64
+        Nq = 4
+        Nkv = 1
+        scale = 1.0
+        L = 256
+
+        mx.random.seed(0)
+        q = 5e-1 * mx.random.normal(shape=(1, Nq, L, D))
+        k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
+        v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
+        ref = mlx_primitives_sdpa(q, k, v, scale, mask=mask)
+        out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
+        self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
+
 
 if __name__ == "__main__":
     unittest.main(failfast=True)