Mirror of https://github.com/ml-explore/mlx.git
fix mask in sdpa (#1980)
* fix mask in sdpa
* fix attention mask
* Re-enable routing for array mask

---------

Co-authored-by: Jagrit Digani <digani@apple.com>
mlx/fast.cpp (41 changed lines)
@@ -650,29 +650,6 @@ array scaled_dot_product_attention(
     throw std::invalid_argument(msg.str());
   }
 
-  if (has_arr_mask) {
-    // Check type
-    auto mask_arr = std::get<array>(mask);
-    has_bool_mask = mask_arr.dtype() == bool_;
-    if (promote_types(mask_arr.dtype(), final_type) != final_type) {
-      std::ostringstream msg;
-      msg << "[scaled_dot_product_attention] Mask type must promote to output type. "
-          << final_type << ".";
-      throw std::invalid_argument(msg.str());
-    }
-    // Check shape
-    auto mask_shape = queries.shape();
-    mask_shape.back() = keys.shape(-2);
-    if (broadcast_shapes(mask_arr.shape(), mask_shape) != mask_shape) {
-      std::ostringstream msg;
-      msg << "[scaled_dot_product_attention] Mask with shape "
-          << mask_arr.shape()
-          << " does not broadcast to implicit scores with shape " << mask_shape
-          << ".";
-      throw std::invalid_argument(msg.str());
-    }
-  }
-
   auto q = astype(queries, final_type, s);
   auto k = astype(keys, final_type, s);
   auto v = astype(values, final_type, s);
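The block removed above validated the array mask's dtype and shape as soon as the arguments were parsed; the same checks reappear in the last hunk, next to where the kernel inputs are assembled. The shape check accepts any mask that broadcasts cleanly to the implicit scores. As a rough sketch of that rule (not MLX code; the helper below only mimics the NumPy-style broadcasting that mlx::core::broadcast_shapes is assumed to follow):

#include <algorithm>
#include <stdexcept>
#include <vector>

// Sketch only: right-aligned broadcasting, where each pair of trailing dims
// must be equal or one of them must be 1.
std::vector<int> broadcast_shapes_sketch(std::vector<int> a, std::vector<int> b) {
  if (a.size() < b.size()) {
    std::swap(a, b);
  }
  std::vector<int> out = a;
  for (size_t i = 0; i < b.size(); ++i) {
    int& da = out[out.size() - 1 - i];
    int db = b[b.size() - 1 - i];
    if (da == db || db == 1) {
      continue;
    } else if (da == 1) {
      da = db;
    } else {
      throw std::invalid_argument("Shapes do not broadcast.");
    }
  }
  return out;
}

int main() {
  // Hypothetical dims: implicit scores are [B, n_heads, L_q, L_k] and the
  // mask is [B, 1, L_q, L_k]. The mask is acceptable because broadcasting it
  // against the scores shape gives back the scores shape unchanged.
  std::vector<int> scores{2, 8, 128, 256};
  std::vector<int> mask{2, 1, 128, 256};
  return broadcast_shapes_sketch(mask, scores) == scores ? 0 : 1;
}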
@@ -748,8 +725,8 @@ array scaled_dot_product_attention(
       (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
 
   const bool sdpa_vector_supported_mask = (!has_mask || has_bool_mask);
-  const bool sdpa_full_supported_mask =
-      !has_mask || (query_sequence_length <= key_sequence_length && do_causal);
+  const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
+      (query_sequence_length <= key_sequence_length && do_causal);
 
   const bool supports_sdpa_full = query_sequence_length >= threshold &&
       sdpa_full_supported_mask && sdpa_full_supported_head_dim &&
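This is the routing change the commit message calls "Re-enable routing for array mask". Before the change, sdpa_full_supported_mask only held with no mask or with the causal case, so an array mask always pushed the call off the fused full-attention kernel; with has_arr_mask added to the predicate, array masks can reach it again. A minimal sketch of the before/after logic with hypothetical flag values (not MLX code):

#include <iostream>

int main() {
  // Hypothetical inputs: an array mask on a 1024 x 1024 attention problem.
  bool has_mask = true;      // some mask was provided
  bool has_arr_mask = true;  // and it is an array, not the causal variant
  bool do_causal = false;
  int query_sequence_length = 1024;
  int key_sequence_length = 1024;

  // Old predicate: an array mask could never satisfy it.
  bool old_ok =
      !has_mask || (query_sequence_length <= key_sequence_length && do_causal);

  // New predicate: an array mask also qualifies for the fused full kernel.
  bool new_ok = !has_mask || has_arr_mask ||
      (query_sequence_length <= key_sequence_length && do_causal);

  std::cout << "old: " << old_ok << ", new: " << new_ok << std::endl;  // old: 0, new: 1
  return 0;
}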
@@ -765,7 +742,19 @@ array scaled_dot_product_attention(
 
   std::vector<array> inputs = {q, k, v};
   if (has_arr_mask) {
-    inputs.push_back(std::get<array>(mask));
+    // Check type
+    auto mask_arr = std::get<array>(mask);
+    has_bool_mask = mask_arr.dtype() == bool_;
+    if (promote_types(mask_arr.dtype(), final_type) != final_type) {
+      std::ostringstream msg;
+      msg << "[scaled_dot_product_attention] Mask type must promote to output type. "
+          << final_type << ".";
+      throw std::invalid_argument(msg.str());
+    }
+    // Broadcast mask
+    auto mask_shape = queries.shape();
+    mask_shape.back() = keys.shape(-2);
+    inputs.push_back(broadcast_to(mask_arr, mask_shape, stream));
   }
   if (implementation_supports_use_case) {
     auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
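The validation removed in the first hunk resurfaces here, with one behavioral difference: instead of only checking that the mask broadcasts against the implicit scores, the mask is now explicitly broadcast to that shape before being pushed into the kernel inputs. The target shape is the queries' shape with its last axis replaced by the keys' sequence length (keys.shape(-2)), i.e. [B, n_heads, L_q, L_k]. A small sketch of that shape computation with hypothetical dimensions (not MLX code):

#include <iostream>
#include <vector>

int main() {
  // Hypothetical shapes: queries [B, n_heads, L_q, head_dim],
  // keys [B, n_kv_heads, L_k, head_dim].
  std::vector<int> queries_shape{2, 8, 128, 64};
  std::vector<int> keys_shape{2, 2, 256, 64};

  // Mirror of the logic above: start from the queries' shape and swap the
  // last axis for the keys' sequence length.
  std::vector<int> mask_shape = queries_shape;
  mask_shape.back() = keys_shape[keys_shape.size() - 2];  // keys.shape(-2) == L_k

  for (int d : mask_shape) {
    std::cout << d << ' ';  // prints: 2 8 128 256
  }
  std::cout << std::endl;
  return 0;
}

The dtype rule itself is unchanged: a mask whose dtype does not promote to the computation type (for example, a float32 mask with a float16 output type) still raises std::invalid_argument.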