mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-29 06:53:18 +08:00 
			
		
		
		
	full row mask in sdpa consistently gives nan (#2406)
This commit is contained in:
		| @@ -708,7 +708,10 @@ array scaled_dot_product_attention( | ||||
|       } | ||||
|       if (mask.dtype() == bool_) { | ||||
|         scores = where( | ||||
|             mask, scores, array(finfo(scores.dtype()).min, scores.dtype())); | ||||
|             mask, | ||||
|             scores, | ||||
|             array(-std::numeric_limits<float>::infinity(), scores.dtype()), | ||||
|             s); | ||||
|       } else { | ||||
|         scores = add(scores, mask, s); | ||||
|       } | ||||
|   | ||||
| @@ -398,6 +398,18 @@ class TestFastSDPA(mlx_tests.MLXTestCase): | ||||
|             ) | ||||
|             self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) | ||||
|  | ||||
|     def test_fully_masked(self): | ||||
|         Lkv = 8 | ||||
|         mask = mx.array(False) | ||||
|         for D in [4, 128]: | ||||
|             for Lq in [1, 8]: | ||||
|                 q = mx.random.normal(shape=(1, 4, Lq, D)) | ||||
|                 k = mx.random.normal(shape=(1, 4, Lkv, D)) | ||||
|                 v = mx.random.normal(shape=(1, 4, Lkv, D)) | ||||
|  | ||||
|                 out = mx.fast.scaled_dot_product_attention(q, k, v, mask=mask, scale=1) | ||||
|                 self.assertTrue(mx.all(mx.isnan(out))) | ||||
|  | ||||
|     def test_fast_sdpa_few_query(self): | ||||
|         D = 64 | ||||
|         L = 43 | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun