mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Allow boolean mask in sdpa (#1753)
* allow boolean mask in sdpa * more permissive donation in ternary
This commit is contained in:
		| @@ -164,9 +164,11 @@ void init_fast(nb::module_& parent_module) { | ||||
|             k (array): Keys with shape ``[B, N_kv, T_kv, D]``. | ||||
|             v (array): Values with shape ``[B, N_kv, T_kv, D]``. | ||||
|             scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``) | ||||
|             mask (array, optional): An additive mask to apply to the query-key | ||||
|                scores. The mask can have at most 4 dimensions and must be | ||||
|                broadcast-compatible with the shape ``[B, N, T_q, T_kv]``. | ||||
|             mask (array, optional): A boolean or additive mask to apply to the | ||||
|                query-key scores. The mask can have at most 4 dimensions and must | ||||
|                be broadcast-compatible with the shape ``[B, N, T_q, T_kv]``. If an | ||||
|                additive mask is given its type must promote to the promoted | ||||
|                type of ``q``, ``k``, and ``v``. | ||||
|         Returns: | ||||
|             array: The output array. | ||||
|       )pbdoc"); | ||||
|   | ||||
| @@ -187,6 +187,17 @@ class TestFastSDPA(mlx_tests.MLXTestCase): | ||||
|         y_hat = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask) | ||||
|         self.assertTrue(mx.allclose(y, y_hat, atol=atol)) | ||||
|  | ||||
|         # Test with boolean causal mask | ||||
|         indices = mx.arange(8) | ||||
|         bool_mask = indices[:, None] >= indices[None] | ||||
|         additive_mask = (~bool_mask).astype(mx.float32) * mx.finfo(mx.float32).min | ||||
|         x = mx.random.normal(shape=(1, 2, 8, 32)) | ||||
|         y = mlx_primitives_sdpa_with_gqa(x, x, x, scale, mask=additive_mask) | ||||
|         y_hat = mx.fast.scaled_dot_product_attention( | ||||
|             x, x, x, scale=scale, mask=bool_mask | ||||
|         ) | ||||
|         self.assertTrue(mx.allclose(y, y_hat, atol=atol)) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main(failfast=True) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun