Allow boolean mask in sdpa (#1753)

* allow boolean mask in sdpa

* more permissive donation in ternary
This commit is contained in:
Awni Hannun
2025-01-06 16:57:07 -08:00
committed by GitHub
parent 25b3a3e541
commit d5ec172c95
4 changed files with 28 additions and 5 deletions

View File

@@ -659,7 +659,12 @@ array scaled_dot_product_attention(
mask = unflatten(mask, -3, {n_kv_heads, n_repeats}, s);
}
}
scores = add(scores, mask, s);
if (mask.dtype() == bool_) {
scores = where(
mask, scores, array(finfo(scores.dtype()).min, scores.dtype()));
} else {
scores = add(scores, mask, s);
}
}
scores = softmax(scores, std::vector<int>{-1}, true, s);
auto out = matmul(scores, v, s);