Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)
Support fused masking in Attention (#1924)
* Update API to allow mask='causal' in fast::sdpa
* Add fallback
* Update steel::AttnParams
* Fix typo
* WIP, basic causal
* Update tests
* Update benchmarking
* Update masking loop limits
* Add bool masking and update tests
* Update additive mask
* Update benchmarks
* Update benchmarks
* Update tests
* Update for bfloat error
* Update early exit
* Add random seed to tests
mlx/fast.cpp (83 changed lines)
@@ -567,7 +567,7 @@ array scaled_dot_product_attention(
     const array& keys,
     const array& values,
     const float scale,
-    const std::optional<array>& mask,
+    const std::variant<std::monostate, std::string, array>& mask /* = {}*/,
     const std::optional<int> memory_efficient_threshold,
     StreamOrDevice s) {
   for (const auto& tensor : {queries, keys, values}) {
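For context, a rough caller-side sketch of the widened mask parameter follows. It is not part of the diff: the umbrella header, the mx alias, the trailing default arguments, and the availability of helpers like random::normal, tril, full, and eval in the C++ API at this commit are assumptions.

// Hedged sketch: exercising the three mask alternatives accepted after this
// change. Shapes are [B, n_heads, L, head_dim].
#include <string>
#include "mlx/mlx.h"

namespace mx = mlx::core;

int main() {
  auto q = mx::random::normal({1, 8, 32, 64});
  auto k = mx::random::normal({1, 8, 32, 64});
  auto v = mx::random::normal({1, 8, 32, 64});
  const float scale = 0.125f; // 1/sqrt(64)

  // 1) No mask: the default-constructed variant holds std::monostate.
  auto o1 = mx::fast::scaled_dot_product_attention(q, k, v, scale);

  // 2) Fused causal masking: pass the string "causal".
  auto o2 = mx::fast::scaled_dot_product_attention(
      q, k, v, scale, std::string("causal"));

  // 3) Explicit boolean mask, broadcastable to [B, n_q_heads, L_q, L_kv].
  auto m = mx::tril(mx::full({32, 32}, true));
  auto o3 = mx::fast::scaled_dot_product_attention(q, k, v, scale, m);

  mx::eval({o1, o2, o3});
  return 0;
}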
@@ -578,10 +578,29 @@ array scaled_dot_product_attention(
       throw std::invalid_argument(msg.str());
     }
   }
-  if (mask && (*mask).ndim() > 4) {
+
+  bool do_causal = false;
+  bool has_mask = !std::holds_alternative<std::monostate>(mask);
+  bool has_str_mask = has_mask && std::holds_alternative<std::string>(mask);
+  bool has_arr_mask = has_mask && std::holds_alternative<array>(mask);
+  bool has_bool_mask = false;
+
+  if (has_str_mask) {
+    if (std::get<std::string>(mask) != "causal") {
+      std::ostringstream msg;
+      msg << "[scaled_dot_product_attention] invalid mask option '"
+          << std::get<std::string>(mask) << "'. Must be 'causal', or an array.";
+      throw std::invalid_argument(msg.str());
+    } else {
+      do_causal = true;
+    }
+  }
+
+  if (has_arr_mask && (std::get<array>(mask)).ndim() > 4) {
     std::ostringstream msg;
     msg << "[scaled_dot_product_attention] the mask with shape "
-        << (*mask).shape() << " expected to have at most rank 4";
+        << (std::get<array>(mask)).shape()
+        << " expected to have at most rank 4";
     throw std::invalid_argument(msg.str());
   }

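The flags above follow the standard std::variant inspection pattern. Here is a minimal, mlx-free sketch of that pattern with plain stand-in types (nothing below is library code):

#include <iostream>
#include <string>
#include <variant>

// monostate means "no mask", a string selects a named mode, and the third
// alternative carries explicit data (an int stands in for an array here).
using Mask = std::variant<std::monostate, std::string, int>;

void describe(const Mask& mask) {
  bool has_mask = !std::holds_alternative<std::monostate>(mask);
  if (!has_mask) {
    std::cout << "no mask\n";
  } else if (std::holds_alternative<std::string>(mask)) {
    std::cout << "named mode: " << std::get<std::string>(mask) << "\n";
  } else {
    std::cout << "explicit mask payload: " << std::get<int>(mask) << "\n";
  }
}

int main() {
  describe({});                    // prints: no mask
  describe(std::string("causal")); // prints: named mode: causal
  describe(42);                    // prints: explicit mask payload: 42
  return 0;
}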
@@ -631,9 +650,11 @@ array scaled_dot_product_attention(
     throw std::invalid_argument(msg.str());
   }

-  if (mask) {
+  if (has_arr_mask) {
     // Check type
-    if (promote_types(mask->dtype(), final_type) != final_type) {
+    auto mask_arr = std::get<array>(mask);
+    has_bool_mask = mask_arr.dtype() == bool_;
+    if (promote_types(mask_arr.dtype(), final_type) != final_type) {
       std::ostringstream msg;
       msg << "[scaled_dot_product_attention] Mask type must promote to output type. "
           << final_type << ".";
@@ -642,9 +663,10 @@ array scaled_dot_product_attention(
     // Check shape
     auto mask_shape = queries.shape();
     mask_shape.back() = keys.shape(-2);
-    if (broadcast_shapes(mask->shape(), mask_shape) != mask_shape) {
+    if (broadcast_shapes(mask_arr.shape(), mask_shape) != mask_shape) {
       std::ostringstream msg;
-      msg << "[scaled_dot_product_attention] Mask with shape " << mask->shape()
+      msg << "[scaled_dot_product_attention] Mask with shape "
+          << mask_arr.shape()
           << " does not broadcast to implicit scores with shape " << mask_shape
           << ".";
       throw std::invalid_argument(msg.str());
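The shape check above validates the mask against the implicit scores shape, i.e. the query shape with its last axis replaced by the key sequence length. A tiny standalone sketch with concrete shapes (Shape treated as a plain vector of ints here):

#include <cassert>
#include <vector>

int main() {
  using Shape = std::vector<int>;
  Shape queries = {1, 8, 32, 64}; // [B, n_q_heads, L_q, head_dim]
  Shape keys = {1, 8, 40, 64};    // [B, n_kv_heads, L_kv, head_dim]

  // mask_shape = queries.shape() with the last axis set to keys.shape(-2):
  Shape mask_shape = queries;
  mask_shape.back() = keys[keys.size() - 2]; // L_kv = 40

  // The mask must broadcast to the scores shape [B, n_q_heads, L_q, L_kv].
  assert((mask_shape == Shape{1, 8, 32, 40}));
  return 0;
}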
@@ -662,7 +684,7 @@ array scaled_dot_product_attention(
     threshold = std::max(1, memory_efficient_threshold.value());
   }

-  auto fallback = [scale, final_type, n_q_heads, n_kv_heads, s](
+  auto fallback = [scale, final_type, n_q_heads, n_kv_heads, do_causal, s](
                       const std::vector<array>& inputs) {
     auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s);
     int n_repeats = n_q_heads / n_kv_heads;
@@ -676,9 +698,21 @@ array scaled_dot_product_attention(
       v = expand_dims(v, 2, s);
     }
     auto scores = matmul(q, swapaxes(k, -1, -2, s), s);
-    if (inputs.size() > 3) {
+    if (inputs.size() > 3 || do_causal) {
       // Mask must be broadcast-compatible with [B, n_q_heads, L_q, L_kv]
-      auto mask = inputs[3];
+      auto mask = inputs.back();
+
+      if (do_causal) {
+        int kL = k.shape(-2);
+        int qL = q.shape(-2);
+        int q_off = (kL - qL) < 0 ? 0 : (kL - qL);
+        auto q_idx = arange(q_off, q_off + qL, s);
+        auto k_idx = arange(0, kL, s);
+        q_idx = expand_dims(q_idx, 1, s);
+        k_idx = expand_dims(k_idx, 0, s);
+        mask = greater_equal(q_idx, k_idx, s);
+      }
+
       if (n_repeats > 1 && mask.ndim() >= 3) {
         if (mask.shape(-3) == 1) {
           mask = expand_dims(mask, -3, s);
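The causal branch in the fallback builds the boolean mask directly from index grids, with q_off shifting the query rows when the key sequence is longer (e.g. prompt plus cache). A rough standalone sketch of the same index math with concrete lengths (qL = 3 queries over kL = 5 keys, so q_off = 2); the ops mirror the diff, while the header path and stream printing are assumptions:

#include <iostream>
#include "mlx/mlx.h"

namespace mx = mlx::core;

int main() {
  int qL = 3, kL = 5;
  int q_off = (kL - qL) < 0 ? 0 : (kL - qL); // = 2

  // Row i may attend to keys j <= i + q_off.
  auto q_idx = mx::expand_dims(mx::arange(q_off, q_off + qL), 1); // [qL, 1]
  auto k_idx = mx::expand_dims(mx::arange(0, kL), 0);             // [1, kL]
  auto mask = mx::greater_equal(q_idx, k_idx);                    // [qL, kL] bool

  // Expected pattern (1 = attend, 0 = masked):
  //   [[1, 1, 1, 0, 0],
  //    [1, 1, 1, 1, 0],
  //    [1, 1, 1, 1, 1]]
  std::cout << mask << std::endl;
  return 0;
}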
@@ -702,9 +736,10 @@ array scaled_dot_product_attention(
   };

   auto stream = to_stream(s);
-  const size_t value_head_dim = v.shape(-1);
-  const size_t query_head_dim = q.shape(-1);
-  const size_t query_sequence_length = q.shape(2);
+  const int value_head_dim = v.shape(-1);
+  const int query_head_dim = q.shape(-1);
+  const int query_sequence_length = q.shape(2);
+  const int key_sequence_length = k.shape(2);

   const bool sdpa_vector_supported_head_dim =
       query_head_dim == value_head_dim &&
@@ -712,27 +747,33 @@ array scaled_dot_product_attention(
   const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
       (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);

-  const bool supports_sdpa_full = query_sequence_length >= threshold && !mask &&
-      sdpa_full_supported_head_dim && stream.device == Device::gpu;
+  const bool sdpa_vector_supported_mask = (!has_mask || has_bool_mask);
+  const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
+      (query_sequence_length <= key_sequence_length && do_causal);
+
+  const bool supports_sdpa_full = query_sequence_length >= threshold &&
+      sdpa_full_supported_mask && sdpa_full_supported_head_dim &&
+      stream.device == Device::gpu;

   const bool supports_sdpa_vector = (query_sequence_length <= 8) &&
-      (query_sequence_length <= k.shape(-2)) &&
-      (!mask || mask->dtype() == bool_) && sdpa_vector_supported_head_dim &&
+      (query_sequence_length <= key_sequence_length) &&
+      sdpa_vector_supported_mask && sdpa_vector_supported_head_dim &&
       stream.device == Device::gpu;

   const bool implementation_supports_use_case =
       supports_sdpa_full || supports_sdpa_vector;

   std::vector<array> inputs = {q, k, v};
-  if (mask) {
-    inputs.push_back(*mask);
+  if (has_arr_mask) {
+    inputs.push_back(std::get<array>(mask));
   }
   if (implementation_supports_use_case) {
     auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
     return array(
         std::move(out_shape),
         final_type,
-        std::make_shared<ScaledDotProductAttention>(stream, fallback, scale),
+        std::make_shared<ScaledDotProductAttention>(
+            stream, fallback, scale, do_causal),
         std::move(inputs));
   }
   return fallback(inputs)[0];
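To see how the new predicates route a typical causal prefill, here is a rough standalone paraphrase with hand-picked values. The names are local to the sketch, and the head-dim and device checks are reduced to booleans for brevity; this is an illustration of the conditions above, not library code:

#include <cstdio>

int main() {
  const int query_sequence_length = 512, key_sequence_length = 512;
  const bool has_mask = true;    // mask = "causal"
  const bool has_arr_mask = false, has_bool_mask = false, do_causal = true;
  const bool head_dim_ok = true; // e.g. query_head_dim == value_head_dim == 128
  const bool on_gpu = true;
  const int threshold = 32;

  const bool sdpa_vector_supported_mask = !has_mask || has_bool_mask;
  const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
      (query_sequence_length <= key_sequence_length && do_causal);

  const bool supports_sdpa_full = query_sequence_length >= threshold &&
      sdpa_full_supported_mask && head_dim_ok && on_gpu;
  const bool supports_sdpa_vector = query_sequence_length <= 8 &&
      query_sequence_length <= key_sequence_length &&
      sdpa_vector_supported_mask && head_dim_ok && on_gpu;

  // Expected: full=1 vector=0, i.e. this causal prefill is routed to the
  // fused full-attention kernel rather than the lambda fallback.
  std::printf("full=%d vector=%d\n", supports_sdpa_full, supports_sdpa_vector);
  return 0;
}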
@@ -741,7 +782,7 @@ array scaled_dot_product_attention(
 bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
   const ScaledDotProductAttention& a_other =
       static_cast<const ScaledDotProductAttention&>(other);
-  return scale_ == a_other.scale_;
+  return scale_ == a_other.scale_ && do_causal_ == a_other.do_causal_;
 }

 array pack_and_quantize(
||||