promote mask when needed (#1998)

This commit is contained in:
Awni Hannun 2025-03-23 19:58:28 -07:00 committed by GitHub
parent f018e248cd
commit a84cc0123f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 18 additions and 0 deletions

View File

@ -750,6 +750,8 @@ array scaled_dot_product_attention(
msg << "[scaled_dot_product_attention] Mask type must promote to output type. " msg << "[scaled_dot_product_attention] Mask type must promote to output type. "
<< final_type << "."; << final_type << ".";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} else if (!has_bool_mask) {
mask_arr = astype(mask_arr, final_type, stream);
} }
// Broadcast mask // Broadcast mask
auto mask_shape = queries.shape(); auto mask_shape = queries.shape();

View File

@ -543,6 +543,22 @@ class TestSDPA(mlx_tests.MLXTestCase):
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask) out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
def test_sdpa_prommote_mask(self):
    """Compare fused SDPA with the reference when the mask is bfloat16.

    The mask is a scalar ``bfloat16`` array, while q/k/v are drawn from
    ``mx.random.normal`` without an explicit dtype (presumably float32 --
    TODO confirm), so the mask dtype differs from the inputs' dtype.
    """
    # NOTE(review): "prommote" looks like a typo for "promote"; the name is
    # left unchanged so existing test selectors keep working.
    bf16_mask = mx.array(2.0, mx.bfloat16)

    head_dim = 64
    num_q_heads = 4
    num_kv_heads = 1
    scale_factor = 1.0
    seq_len = 256

    q_shape = (1, num_q_heads, seq_len, head_dim)
    kv_shape = (1, num_kv_heads, seq_len, head_dim)

    # Seed first, then draw q, k, v in this exact order so the inputs are
    # reproducible from run to run.
    mx.random.seed(0)
    queries = 0.5 * mx.random.normal(shape=q_shape)
    keys = 0.5 * mx.random.normal(shape=kv_shape)
    values = 0.5 * mx.random.normal(shape=kv_shape)

    expected = mlx_primitives_sdpa(
        queries, keys, values, scale_factor, mask=bf16_mask
    )
    actual = mx.fast.scaled_dot_product_attention(
        queries, keys, values, scale=scale_factor, mask=bf16_mask
    )

    matches = mx.allclose(expected, actual, atol=1e-4, rtol=1e-4)
    self.assertTrue(matches)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main(failfast=True) unittest.main(failfast=True)