cpu fallback

This commit is contained in:
Alex Barron 2024-12-06 01:22:50 -08:00
parent c89ddf62b4
commit 769704653a
2 changed files with 2 additions and 2 deletions

View File

@@ -792,7 +792,8 @@ array quantized_scaled_dot_product_attention(
   int query_head_dim = queries.shape(-1);
   int L = queries.shape(2);
-  if (L > 1 && query_head_dim != 64 && query_head_dim != 128) {
+  bool compatible_head_dim = query_head_dim == 64 || query_head_dim == 128;
+  if (L > 1 || !compatible_head_dim || stream.device != Device::gpu) {
     if (needs_mask) {
       return fallback(
           {queries,

View File

@@ -168,7 +168,6 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
         if dtype == mx.float16:
             rtol = 1e-2
-        # np.testing.assert_allclose(o_q, reference, rtol=rtol, atol=atol)
         self.assertTrue(mx.allclose(o_q, reference, rtol=rtol, atol=atol))
         self.assertTrue(mx.allclose(o, reference, rtol=rtol, atol=atol))