Support fused masking in Attention (#1924)

* Update API to allow mask='causal' in fast::sdpa (see the usage sketch before the diff below)

* Add fallback

* Update steel::AttnParams

* Fix typo

* WIP, basic causal

* Update tests

* Update benchmarking

* Update masking loop limits

* Add bool masking and update tests

* Update additive mask

* Update benchmarks

* Update benchmarks

* Update tests

* Update for bfloat error

* Update early exit

* Add random seed to tests
Jagrit Digani
2025-03-20 11:01:32 -07:00
committed by GitHub
parent 3c164fca8c
commit 9adcd1a650
11 changed files with 504 additions and 148 deletions
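
The first change below replaces the std::optional<array> mask parameter of fast::scaled_dot_product_attention with a std::variant<std::monostate, std::string, array>, so a caller can request causal masking by passing the string "causal" instead of materializing a mask array. The following is a minimal usage sketch, not taken from this commit: the mlx/mlx.h umbrella header, the namespace alias, the shape-only random::normal overload, and the defaulted memory_efficient_threshold and stream arguments are assumptions about the surrounding library rather than anything shown in this diff.

// Hypothetical caller of the new API; for illustration only.
#include <cmath>
#include <string>

#include "mlx/mlx.h" // assumed umbrella header exposing mlx::core::fast

namespace mx = mlx::core;

int main() {
  // Inputs are rank-4: [batch, n_heads, sequence_length, head_dim].
  // head_dim = 64 is one of the head sizes the fused kernels support.
  auto q = mx::random::normal({1, 8, 32, 64});
  auto k = mx::random::normal({1, 8, 32, 64});
  auto v = mx::random::normal({1, 8, 32, 64});
  float scale = 1.0f / std::sqrt(64.0f);

  // New in this commit: pass "causal" instead of building a mask array.
  auto out = mx::fast::scaled_dot_product_attention(
      q, k, v, scale, std::string("causal"));
  out.eval(); // out has shape [1, 8, 32, 64]
  return 0;
}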


@@ -567,7 +567,7 @@ array scaled_dot_product_attention(
     const array& keys,
     const array& values,
     const float scale,
-    const std::optional<array>& mask,
+    const std::variant<std::monostate, std::string, array>& mask /* = {}*/,
     const std::optional<int> memory_efficient_threshold,
     StreamOrDevice s) {
   for (const auto& tensor : {queries, keys, values}) {
@@ -578,10 +578,29 @@ array scaled_dot_product_attention(
       throw std::invalid_argument(msg.str());
     }
   }
-  if (mask && (*mask).ndim() > 4) {
+  bool do_causal = false;
+  bool has_mask = !std::holds_alternative<std::monostate>(mask);
+  bool has_str_mask = has_mask && std::holds_alternative<std::string>(mask);
+  bool has_arr_mask = has_mask && std::holds_alternative<array>(mask);
+  bool has_bool_mask = false;
+  if (has_str_mask) {
+    if (std::get<std::string>(mask) != "causal") {
+      std::ostringstream msg;
+      msg << "[scaled_dot_product_attention] invalid mask option '"
+          << std::get<std::string>(mask) << "'. Must be 'causal', or an array.";
+      throw std::invalid_argument(msg.str());
+    } else {
+      do_causal = true;
+    }
+  }
+  if (has_arr_mask && (std::get<array>(mask)).ndim() > 4) {
     std::ostringstream msg;
     msg << "[scaled_dot_product_attention] the mask with shape "
-        << (*mask).shape() << " expected to have at most rank 4";
+        << (std::get<array>(mask)).shape()
+        << " expected to have at most rank 4";
     throw std::invalid_argument(msg.str());
   }
@@ -631,9 +650,11 @@ array scaled_dot_product_attention(
     throw std::invalid_argument(msg.str());
   }
-  if (mask) {
+  if (has_arr_mask) {
     // Check type
-    if (promote_types(mask->dtype(), final_type) != final_type) {
+    auto mask_arr = std::get<array>(mask);
+    has_bool_mask = mask_arr.dtype() == bool_;
+    if (promote_types(mask_arr.dtype(), final_type) != final_type) {
       std::ostringstream msg;
       msg << "[scaled_dot_product_attention] Mask type must promote to output type. "
           << final_type << ".";
@@ -642,9 +663,10 @@ array scaled_dot_product_attention(
     // Check shape
     auto mask_shape = queries.shape();
     mask_shape.back() = keys.shape(-2);
-    if (broadcast_shapes(mask->shape(), mask_shape) != mask_shape) {
+    if (broadcast_shapes(mask_arr.shape(), mask_shape) != mask_shape) {
       std::ostringstream msg;
-      msg << "[scaled_dot_product_attention] Mask with shape " << mask->shape()
+      msg << "[scaled_dot_product_attention] Mask with shape "
+          << mask_arr.shape()
           << " does not broadcast to implicit scores with shape " << mask_shape
           << ".";
       throw std::invalid_argument(msg.str());
@@ -662,7 +684,7 @@ array scaled_dot_product_attention(
     threshold = std::max(1, memory_efficient_threshold.value());
   }
-  auto fallback = [scale, final_type, n_q_heads, n_kv_heads, s](
+  auto fallback = [scale, final_type, n_q_heads, n_kv_heads, do_causal, s](
                       const std::vector<array>& inputs) {
     auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s);
     int n_repeats = n_q_heads / n_kv_heads;
@@ -676,9 +698,21 @@ array scaled_dot_product_attention(
       v = expand_dims(v, 2, s);
     }
     auto scores = matmul(q, swapaxes(k, -1, -2, s), s);
-    if (inputs.size() > 3) {
+    if (inputs.size() > 3 || do_causal) {
       // Mask must be broadcast-compatible with [B, n_q_heads, L_q, L_kv]
-      auto mask = inputs[3];
+      auto mask = inputs.back();
+      if (do_causal) {
+        int kL = k.shape(-2);
+        int qL = q.shape(-2);
+        int q_off = (kL - qL) < 0 ? 0 : (kL - qL);
+        auto q_idx = arange(q_off, q_off + qL, s);
+        auto k_idx = arange(0, kL, s);
+        q_idx = expand_dims(q_idx, 1, s);
+        k_idx = expand_dims(k_idx, 0, s);
+        mask = greater_equal(q_idx, k_idx, s);
+      }
       if (n_repeats > 1 && mask.ndim() >= 3) {
         if (mask.shape(-3) == 1) {
           mask = expand_dims(mask, -3, s);
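
The do_causal branch above anchors the query window to the end of the key sequence: q_off = max(kL - qL, 0), so when the keys include cached positions (kL > qL) the last query row may attend to every key. Below is a standalone sketch of the same index construction, reusing the arange/expand_dims/greater_equal ops the lambda calls; the default-stream overloads and the printout are assumptions for illustration, not part of the fallback itself.

// Hypothetical illustration of the fallback's causal mask; not part of the diff.
#include <iostream>

#include "mlx/mlx.h"

namespace mx = mlx::core;

int main() {
  int kL = 5; // 3 cached key positions plus 2 new ones
  int qL = 2; // only the last 2 query positions are computed
  int q_off = (kL - qL) < 0 ? 0 : (kL - qL); // = 3

  auto q_idx = mx::expand_dims(mx::arange(q_off, q_off + qL), 1); // [qL, 1]
  auto k_idx = mx::expand_dims(mx::arange(0, kL), 0);             // [1, kL]
  auto mask = mx::greater_equal(q_idx, k_idx);                    // [qL, kL], bool

  // Mask pattern (1 = allowed to attend), row i covers keys 0 .. q_off + i:
  //   [[1, 1, 1, 1, 0],
  //    [1, 1, 1, 1, 1]]
  std::cout << mask << std::endl;
  return 0;
}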
@@ -702,9 +736,10 @@ array scaled_dot_product_attention(
   };
   auto stream = to_stream(s);
-  const size_t value_head_dim = v.shape(-1);
-  const size_t query_head_dim = q.shape(-1);
-  const size_t query_sequence_length = q.shape(2);
+  const int value_head_dim = v.shape(-1);
+  const int query_head_dim = q.shape(-1);
+  const int query_sequence_length = q.shape(2);
+  const int key_sequence_length = k.shape(2);
   const bool sdpa_vector_supported_head_dim =
       query_head_dim == value_head_dim &&
@@ -712,27 +747,33 @@ array scaled_dot_product_attention(
   const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
       (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
-  const bool supports_sdpa_full = query_sequence_length >= threshold && !mask &&
-      sdpa_full_supported_head_dim && stream.device == Device::gpu;
+  const bool sdpa_vector_supported_mask = (!has_mask || has_bool_mask);
+  const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
+      (query_sequence_length <= key_sequence_length && do_causal);
+  const bool supports_sdpa_full = query_sequence_length >= threshold &&
+      sdpa_full_supported_mask && sdpa_full_supported_head_dim &&
+      stream.device == Device::gpu;
   const bool supports_sdpa_vector = (query_sequence_length <= 8) &&
-      (query_sequence_length <= k.shape(-2)) &&
-      (!mask || mask->dtype() == bool_) && sdpa_vector_supported_head_dim &&
+      (query_sequence_length <= key_sequence_length) &&
+      sdpa_vector_supported_mask && sdpa_vector_supported_head_dim &&
       stream.device == Device::gpu;
   const bool implementation_supports_use_case =
       supports_sdpa_full || supports_sdpa_vector;
   std::vector<array> inputs = {q, k, v};
-  if (mask) {
-    inputs.push_back(*mask);
+  if (has_arr_mask) {
+    inputs.push_back(std::get<array>(mask));
   }
   if (implementation_supports_use_case) {
     auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
     return array(
         std::move(out_shape),
         final_type,
-        std::make_shared<ScaledDotProductAttention>(stream, fallback, scale),
+        std::make_shared<ScaledDotProductAttention>(
+            stream, fallback, scale, do_causal),
         std::move(inputs));
   }
   return fallback(inputs)[0];
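
Per the conditions above, the vector kernel still requires either no mask or a boolean mask (sdpa_vector_supported_mask), while the full kernel now also accepts an array mask, or the causal option when the query length does not exceed the key length; anything else goes through the fallback. Below is a hedged sketch of passing an explicit boolean mask through the same entry point, here for a single decoded query token against 8 keys; the shapes, header, and defaulted trailing arguments are assumptions for illustration.

// Hypothetical boolean-mask caller; for illustration only.
#include <cmath>

#include "mlx/mlx.h"

namespace mx = mlx::core;

int main() {
  // Decoding shape: one query token, 8 cached keys, 4 heads, head_dim 64.
  auto q = mx::random::normal({1, 4, 1, 64});
  auto k = mx::random::normal({1, 4, 8, 64});
  auto v = mx::random::normal({1, 4, 8, 64});

  // Boolean mask broadcastable to [B, n_heads, L_q, L_kv]: keep the first 6 keys.
  auto keep = mx::less(mx::arange(8), mx::array(6)); // [8], bool
  auto mask = mx::reshape(keep, {1, 1, 1, 8});

  auto out = mx::fast::scaled_dot_product_attention(
      q, k, v, 1.0f / std::sqrt(64.0f), mask);
  out.eval(); // out has shape [1, 4, 1, 64]
  return 0;
}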
@@ -741,7 +782,7 @@ array scaled_dot_product_attention(
 bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
   const ScaledDotProductAttention& a_other =
       static_cast<const ScaledDotProductAttention&>(other);
-  return scale_ == a_other.scale_;
+  return scale_ == a_other.scale_ && do_causal_ == a_other.do_causal_;
 }
 array pack_and_quantize(