Mirror of https://github.com/ml-explore/mlx.git
Synced 2025-12-16 01:49:05 +08:00
Add boolean mask support in vector SDPA (#1757)
 mlx/fast.cpp | 51
@@ -609,27 +609,32 @@ array scaled_dot_product_attention(
     throw std::invalid_argument(msg.str());
   }
 
-  if (mask && promote_types((*mask).dtype(), final_type) != final_type) {
-    std::ostringstream msg;
-    msg << "[scaled_dot_product_attention] Mask type must promote to output type. "
-        << final_type << ".";
-    throw std::invalid_argument(msg.str());
+  if (mask) {
+    // Check type
+    if (promote_types(mask->dtype(), final_type) != final_type) {
+      std::ostringstream msg;
+      msg << "[scaled_dot_product_attention] Mask type must promote to output type. "
+          << final_type << ".";
+      throw std::invalid_argument(msg.str());
+    }
+    // Check shape
+    auto mask_shape = queries.shape();
+    mask_shape.back() = keys.shape(-2);
+    if (broadcast_shapes(mask->shape(), mask_shape) != mask_shape) {
+      std::ostringstream msg;
+      msg << "[scaled_dot_product_attention] Mask with shape " << mask->shape()
+          << " does not broadcast to implicit scores with shape " << mask_shape
+          << ".";
+      throw std::invalid_argument(msg.str());
+    }
   }
 
   auto q = astype(queries, final_type, s);
   auto k = astype(keys, final_type, s);
   auto v = astype(values, final_type, s);
 
-  /* generic implementation for use cases that Metal implementation does not
-   * support. For non-supported cases listed below, use MLX primitives:
-   * * CPU implementation
-   * * batch size > 1 for decoding or causal attention
-   * * query sequence length > 1 for decoding
-   * * query sequence length > 16 && non-null mask (causal attention)
-   * * non-null mask
-   * * dtype is not fp32 or fp16
-   */
-
+  /* Generic implementation for use cases that Metal implementation does not
+   * support. */
   int threshold = 32; // TODO: Fix after dev
   if (memory_efficient_threshold.has_value()) {
     threshold = std::max(1, memory_efficient_threshold.value());
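
The new shape check validates the mask against the implicit attention scores: queries are (B, H, L_q, D) and keys are (B, H, L_k, D), so the scores are (B, H, L_q, L_k), built by taking queries.shape() and replacing the last axis with keys.shape(-2). Below is a minimal standalone sketch of the broadcast rule being enforced; broadcasts_to is a hypothetical stand-in for MLX's broadcast_shapes, reduced to a pass/fail test against a fixed target shape.

// Standalone sketch of the broadcast rule the new mask check enforces.
// broadcasts_to is a hypothetical helper, not an MLX API.
#include <cassert>
#include <vector>

using Shape = std::vector<int>;

bool broadcasts_to(const Shape& shape, const Shape& target) {
  if (shape.size() > target.size()) {
    return false;
  }
  auto s = shape.rbegin();
  for (auto t = target.rbegin(); s != shape.rend(); ++s, ++t) {
    // NumPy-style broadcasting: trailing axes must match or be 1.
    if (*s != *t && *s != 1) {
      return false;
    }
  }
  return true;
}

int main() {
  // With queries (1, 8, 1, 64) and keys (1, 8, 128, 64), the implicit
  // scores are (1, 8, 1, 128): queries.shape() with the last axis
  // replaced by keys.shape(-2), exactly as in the diff above.
  Shape scores = {1, 8, 1, 128};
  assert(broadcasts_to({128}, scores));            // per-key mask: accepted
  assert(broadcasts_to({1, 1, 1, 128}, scores));   // explicit 4-d mask: accepted
  assert(!broadcasts_to({1, 8, 2, 128}, scores));  // wrong query length: rejected
  return 0;
}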
@@ -690,27 +695,27 @@ array scaled_dot_product_attention(
       !mask.has_value() && sdpa_full_supported_head_dim &&
       stream.device == Device::gpu;
 
+  const bool supported_mask = !mask || (mask->dtype() == bool_);
   const bool supports_sdpa_vector = query_sequence_length == 1 &&
-      !mask.has_value() && sdpa_vector_supported_head_dim &&
+      supported_mask && sdpa_vector_supported_head_dim &&
       stream.device == Device::gpu;
 
   implementation_supports_use_case &=
       supports_sdpa_full || supports_sdpa_vector;
 
+  std::vector<array> inputs = {q, k, v};
+  if (mask) {
+    inputs.push_back(*mask);
+  }
   if (implementation_supports_use_case) {
     auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
     return array(
         std::move(out_shape),
         final_type,
         std::make_shared<ScaledDotProductAttention>(stream, fallback, scale),
-        {q, k, v});
+        std::move(inputs));
   }
-
-  if (mask.has_value()) {
-    return fallback({q, k, v, mask.value()})[0];
-  } else {
-    return fallback({q, k, v})[0];
-  }
+  return fallback(inputs)[0];
 }
 
 bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
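
With supported_mask in place, a single-token decode call that passes a boolean mask can now take the Metal vector SDPA path instead of the generic fallback. The sketch below shows how a caller might exercise it; the argument order of fast::scaled_dot_product_attention is assumed from this diff rather than quoted from mlx/fast.h, and the all-true mask is a hypothetical placeholder.

// Sketch: exercising the new boolean-mask vector SDPA path.
// Assumed signature: fast::scaled_dot_product_attention(q, k, v, scale, mask);
// check mlx/fast.h for the authoritative declaration.
#include <cmath>
#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  int B = 1, H = 8, D = 64, L_k = 128;
  auto q = random::normal({B, H, 1, D});   // query_sequence_length == 1 (decode)
  auto k = random::normal({B, H, L_k, D});
  auto v = random::normal({B, H, L_k, D});
  // Boolean mask over the keys: it broadcasts to the (B, H, 1, L_k) scores
  // and, with this change, satisfies supported_mask (dtype == bool_).
  auto mask = astype(ones({L_k}), bool_);
  float scale = 1.0f / std::sqrt(static_cast<float>(D));
  auto out = fast::scaled_dot_product_attention(q, k, v, scale, mask);
  eval(out);
  return 0;
}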