diff --git a/mlx/backend/common/ternary.h b/mlx/backend/common/ternary.h index 050e07e022..eb513a12f4 100644 --- a/mlx/backend/common/ternary.h +++ b/mlx/backend/common/ternary.h @@ -67,7 +67,12 @@ void set_ternary_op_output_data( } break; case TernaryOpType::General: - out.set_data(allocator::malloc_or_wait(out.nbytes())); + // Try to donate an input which is row_contiguous + if (!((a.flags().row_contiguous && maybe_donate(a)) || + (b.flags().row_contiguous && maybe_donate(b)) || + (c.flags().row_contiguous && maybe_donate(c)))) { + out.set_data(allocator::malloc_or_wait(out.nbytes())); + } break; } } diff --git a/mlx/fast.cpp b/mlx/fast.cpp index ca80289a9c..7b3570b903 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -659,7 +659,12 @@ array scaled_dot_product_attention( mask = unflatten(mask, -3, {n_kv_heads, n_repeats}, s); } } - scores = add(scores, mask, s); + if (mask.dtype() == bool_) { + scores = where( + mask, scores, array(finfo(scores.dtype()).min, scores.dtype())); + } else { + scores = add(scores, mask, s); + } } scores = softmax(scores, std::vector{-1}, true, s); auto out = matmul(scores, v, s); diff --git a/python/src/fast.cpp b/python/src/fast.cpp index 91a7293fb8..d7ccc000b4 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -164,9 +164,11 @@ void init_fast(nb::module_& parent_module) { k (array): Keys with shape ``[B, N_kv, T_kv, D]``. v (array): Values with shape ``[B, N_kv, T_kv, D]``. scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``) - mask (array, optional): An additive mask to apply to the query-key - scores. The mask can have at most 4 dimensions and must be - broadcast-compatible with the shape ``[B, N, T_q, T_kv]``. + mask (array, optional): A boolean or additive mask to apply to the + query-key scores. The mask can have at most 4 dimensions and must + be broadcast-compatible with the shape ``[B, N, T_q, T_kv]``. 
If an + additive mask is given, its type must promote to the promoted + type of ``q``, ``k``, and ``v``. Returns: array: The output array. )pbdoc"); diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py index 1df48bc7f7..3b86ef17de 100644 --- a/python/tests/test_fast_sdpa.py +++ b/python/tests/test_fast_sdpa.py @@ -187,6 +187,17 @@ class TestFastSDPA(mlx_tests.MLXTestCase): y_hat = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask) self.assertTrue(mx.allclose(y, y_hat, atol=atol)) + # Test with boolean causal mask + indices = mx.arange(8) + bool_mask = indices[:, None] >= indices[None] + additive_mask = (~bool_mask).astype(mx.float32) * mx.finfo(mx.float32).min + x = mx.random.normal(shape=(1, 2, 8, 32)) + y = mlx_primitives_sdpa_with_gqa(x, x, x, scale, mask=additive_mask) + y_hat = mx.fast.scaled_dot_product_attention( + x, x, x, scale=scale, mask=bool_mask + ) + self.assertTrue(mx.allclose(y, y_hat, atol=atol)) + if __name__ == "__main__": unittest.main(failfast=True)