Add boolean mask support in vector SDPA (#1757)

Awni Hannun authored on 2025-01-07 20:24:53 -08:00, committed by GitHub
parent 516ded618b
commit d1766f2c70
5 changed files with 226 additions and 74 deletions
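
For context, the sketch below shows the kind of call this change affects. It is illustrative only and not taken from the diff: the shapes, the random:: helpers, and the exact fast::scaled_dot_product_attention signature (queries, keys, values, scale, optional mask) are assumptions based on the C++ API around this revision. The idea is that a boolean mask broadcastable to the attention scores can now reach the fast single-query (vector) kernel rather than forcing the generic fallback.

// Hedged sketch, not from the commit; API details assumed.
#include <cmath>
#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  int B = 1, H = 8, L_k = 128, D = 64;
  auto q = random::normal({B, H, 1, D}, float16);   // single decode-step query
  auto k = random::normal({B, H, L_k, D}, float16);
  auto v = random::normal({B, H, L_k, D}, float16);
  // Boolean mask (true = attend); it broadcasts to the score shape [B, H, 1, L_k].
  auto mask = greater(random::uniform({1, 1, 1, L_k}), array(0.5f));
  float scale = 1.0f / std::sqrt(static_cast<float>(D));
  auto out = fast::scaled_dot_product_attention(q, k, v, scale, mask);
  eval(out);
  return 0;
}

Per the second hunk below, a non-boolean mask (for example an additive float mask) still takes the fallback path for the vector case.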


@@ -609,27 +609,32 @@ array scaled_dot_product_attention(
     throw std::invalid_argument(msg.str());
   }
-  if (mask && promote_types((*mask).dtype(), final_type) != final_type) {
-    std::ostringstream msg;
-    msg << "[scaled_dot_product_attention] Mask type must promote to output type. "
-        << final_type << ".";
-    throw std::invalid_argument(msg.str());
+  if (mask) {
+    // Check type
+    if (promote_types(mask->dtype(), final_type) != final_type) {
+      std::ostringstream msg;
+      msg << "[scaled_dot_product_attention] Mask type must promote to output type. "
+          << final_type << ".";
+      throw std::invalid_argument(msg.str());
+    }
+    // Check shape
+    auto mask_shape = queries.shape();
+    mask_shape.back() = keys.shape(-2);
+    if (broadcast_shapes(mask->shape(), mask_shape) != mask_shape) {
+      std::ostringstream msg;
+      msg << "[scaled_dot_product_attention] Mask with shape " << mask->shape()
+          << " does not broadcast to implicit scores with shape " << mask_shape
+          << ".";
+      throw std::invalid_argument(msg.str());
+    }
   }
   auto q = astype(queries, final_type, s);
   auto k = astype(keys, final_type, s);
   auto v = astype(values, final_type, s);
-  /* generic implementation for use cases that Metal implementation does not
-   * support. For non-supported cases listed below, use MLX primitives:
-   * * CPU implementation
-   * * batch size > 1 for decoding or causal attention
-   * * query sequence length > 1 for decoding
-   * * query sequence length > 16 && non-null mask (causal attention)
-   * * non-null mask
-   * * dtype is not fp32 or fp16
-   */
+  /* Generic implementation for use cases that Metal implementation does not
+   * support. */
   int threshold = 32; // TODO: Fix after dev
   if (memory_efficient_threshold.has_value()) {
     threshold = std::max(1, memory_efficient_threshold.value());
@@ -690,27 +695,27 @@ array scaled_dot_product_attention(
       !mask.has_value() && sdpa_full_supported_head_dim &&
       stream.device == Device::gpu;
+  const bool supported_mask = !mask || (mask->dtype() == bool_);
   const bool supports_sdpa_vector = query_sequence_length == 1 &&
-      !mask.has_value() && sdpa_vector_supported_head_dim &&
+      supported_mask && sdpa_vector_supported_head_dim &&
       stream.device == Device::gpu;
   implementation_supports_use_case &=
       supports_sdpa_full || supports_sdpa_vector;
+  std::vector<array> inputs = {q, k, v};
+  if (mask) {
+    inputs.push_back(*mask);
+  }
   if (implementation_supports_use_case) {
     auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
     return array(
         std::move(out_shape),
         final_type,
         std::make_shared<ScaledDotProductAttention>(stream, fallback, scale),
-        {q, k, v});
-  }
-  if (mask.has_value()) {
-    return fallback({q, k, v, mask.value()})[0];
-  } else {
-    return fallback({q, k, v})[0];
+        std::move(inputs));
   }
+  return fallback(inputs)[0];
 }
 bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
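
The shape check added in the first hunk derives the implicit score shape by taking the query shape and swapping its last axis for the key sequence length, then requires the mask to broadcast to that shape. Below is a minimal sketch of that arithmetic with made-up shapes; the [batch, n_heads, seq_len, head_dim] layout is an assumption.

// Hedged sketch of the broadcast requirement; shapes are invented for illustration.
#include <cassert>
#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  Shape q_shape = {2, 8, 1, 64};   // [B, n_heads, L_q, head_dim]
  Shape k_shape = {2, 8, 128, 64}; // [B, n_heads, L_k, head_dim]

  // Same derivation as the check: query shape with its last axis replaced
  // by the key sequence length.
  Shape scores_shape = q_shape;
  scores_shape.back() = k_shape[k_shape.size() - 2]; // {2, 8, 1, 128}

  // A per-key boolean mask broadcasts cleanly; a mask whose key axis does not
  // match (e.g. {1, 1, 1, 64}) fails, and broadcast_shapes throws.
  Shape mask_shape = {1, 1, 1, 128};
  assert(broadcast_shapes(mask_shape, scores_shape) == scores_shape);
  return 0;
}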