mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	cpu fallback
This commit is contained in:
		| @@ -792,7 +792,8 @@ array quantized_scaled_dot_product_attention( | ||||
|  | ||||
|   int query_head_dim = queries.shape(-1); | ||||
|   int L = queries.shape(2); | ||||
|   if (L > 1 && query_head_dim != 64 && query_head_dim != 128) { | ||||
|   bool compatible_head_dim = query_head_dim == 64 || query_head_dim == 128; | ||||
|   if (L > 1 || !compatible_head_dim || stream.device != Device::gpu) { | ||||
|     if (needs_mask) { | ||||
|       return fallback( | ||||
|           {queries, | ||||
|   | ||||
| @@ -168,7 +168,6 @@ class TestFastSDPA(mlx_tests.MLXTestCase): | ||||
|                 if dtype == mx.float16: | ||||
|                     rtol = 1e-2 | ||||
|  | ||||
|                 # np.testing.assert_allclose(o_q, reference, rtol=rtol, atol=atol) | ||||
|                 self.assertTrue(mx.allclose(o_q, reference, rtol=rtol, atol=atol)) | ||||
|                 self.assertTrue(mx.allclose(o, reference, rtol=rtol, atol=atol)) | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Alex Barron
					Alex Barron