Allow boolean mask in sdpa (#1753)

* allow boolean mask in sdpa

* more permissive donation in ternary
This commit is contained in:
Awni Hannun 2025-01-06 16:57:07 -08:00 committed by GitHub
parent 25b3a3e541
commit d5ec172c95
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 28 additions and 5 deletions

View File

@ -67,7 +67,12 @@ void set_ternary_op_output_data(
}
break;
case TernaryOpType::General:
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// Try to donate an input which is row_contiguous
if (!((a.flags().row_contiguous && maybe_donate(a)) ||
(b.flags().row_contiguous && maybe_donate(b)) ||
(c.flags().row_contiguous && maybe_donate(c)))) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
break;
}
}

View File

@ -659,7 +659,12 @@ array scaled_dot_product_attention(
mask = unflatten(mask, -3, {n_kv_heads, n_repeats}, s);
}
}
scores = add(scores, mask, s);
if (mask.dtype() == bool_) {
scores = where(
mask, scores, array(finfo(scores.dtype()).min, scores.dtype()));
} else {
scores = add(scores, mask, s);
}
}
scores = softmax(scores, std::vector<int>{-1}, true, s);
auto out = matmul(scores, v, s);

View File

@ -164,9 +164,11 @@ void init_fast(nb::module_& parent_module) {
k (array): Keys with shape ``[B, N_kv, T_kv, D]``.
v (array): Values with shape ``[B, N_kv, T_kv, D]``.
scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1))``)
mask (array, optional): An additive mask to apply to the query-key
scores. The mask can have at most 4 dimensions and must be
broadcast-compatible with the shape ``[B, N, T_q, T_kv]``.
mask (array, optional): A boolean or additive mask to apply to the
query-key scores. The mask can have at most 4 dimensions and must
be broadcast-compatible with the shape ``[B, N, T_q, T_kv]``. If an
additive mask is given, its type must promote to the promoted
type of ``q``, ``k``, and ``v``.
Returns:
array: The output array.
)pbdoc");

View File

@ -187,6 +187,17 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
y_hat = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
self.assertTrue(mx.allclose(y, y_hat, atol=atol))
# Test with boolean causal mask
indices = mx.arange(8)
bool_mask = indices[:, None] >= indices[None]
additive_mask = (~bool_mask).astype(mx.float32) * mx.finfo(mx.float32).min
x = mx.random.normal(shape=(1, 2, 8, 32))
y = mlx_primitives_sdpa_with_gqa(x, x, x, scale, mask=additive_mask)
y_hat = mx.fast.scaled_dot_product_attention(
x, x, x, scale=scale, mask=bool_mask
)
self.assertTrue(mx.allclose(y, y_hat, atol=atol))
if __name__ == "__main__":
unittest.main(failfast=True)