mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-24 02:41:19 +08:00
More lenient mask type check in SDPA (#1723)
* check mask type * require promotion
This commit is contained in:
parent
ed4ec81bca
commit
f17536af9c
@ -586,9 +586,9 @@ array scaled_dot_product_attention(
|
|||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (mask && (*mask).dtype() != final_type) {
|
if (mask && promote_types((*mask).dtype(), final_type) != final_type) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[scaled_dot_product_attention] Mask should match output type. "
|
msg << "[scaled_dot_product_attention] Mask type must promote to output type. "
|
||||||
<< final_type << ".";
|
<< final_type << ".";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user