mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-30 06:12:41 +08:00
cpu fallback
This commit is contained in:
parent
c89ddf62b4
commit
769704653a
@ -792,7 +792,8 @@ array quantized_scaled_dot_product_attention(
|
|||||||
|
|
||||||
int query_head_dim = queries.shape(-1);
|
int query_head_dim = queries.shape(-1);
|
||||||
int L = queries.shape(2);
|
int L = queries.shape(2);
|
||||||
if (L > 1 && query_head_dim != 64 && query_head_dim != 128) {
|
bool compatible_head_dim = query_head_dim == 64 || query_head_dim == 128;
|
||||||
|
if (L > 1 || !compatible_head_dim || stream.device != Device::gpu) {
|
||||||
if (needs_mask) {
|
if (needs_mask) {
|
||||||
return fallback(
|
return fallback(
|
||||||
{queries,
|
{queries,
|
||||||
|
@ -168,7 +168,6 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
|
|||||||
if dtype == mx.float16:
|
if dtype == mx.float16:
|
||||||
rtol = 1e-2
|
rtol = 1e-2
|
||||||
|
|
||||||
# np.testing.assert_allclose(o_q, reference, rtol=rtol, atol=atol)
|
|
||||||
self.assertTrue(mx.allclose(o_q, reference, rtol=rtol, atol=atol))
|
self.assertTrue(mx.allclose(o_q, reference, rtol=rtol, atol=atol))
|
||||||
self.assertTrue(mx.allclose(o, reference, rtol=rtol, atol=atol))
|
self.assertTrue(mx.allclose(o, reference, rtol=rtol, atol=atol))
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user