More lenient mask type check in SDPA (#1723)

* check mask type

* require promotion
This commit is contained in:
Alex Barron 2024-12-18 19:41:38 -08:00 committed by GitHub
parent ed4ec81bca
commit f17536af9c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -586,9 +586,9 @@ array scaled_dot_product_attention(
throw std::invalid_argument(msg.str());
}
if (mask && (*mask).dtype() != final_type) {
if (mask && promote_types((*mask).dtype(), final_type) != final_type) {
std::ostringstream msg;
msg << "[scaled_dot_product_attention] Mask should match output type. "
msg << "[scaled_dot_product_attention] Mask type must promote to output type. "
<< final_type << ".";
throw std::invalid_argument(msg.str());
}