mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-14 05:06:39 +08:00
check mask type (#1721)
This commit is contained in:
parent
5548fcc96d
commit
49c34c4161
@ -586,6 +586,13 @@ array scaled_dot_product_attention(
|
|||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (mask && (*mask).dtype() != final_type) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[scaled_dot_product_attention] Mask should match output type. "
|
||||||
|
<< final_type << ".";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
auto q = astype(queries, final_type, s);
|
auto q = astype(queries, final_type, s);
|
||||||
auto k = astype(keys, final_type, s);
|
auto k = astype(keys, final_type, s);
|
||||||
auto v = astype(values, final_type, s);
|
auto v = astype(values, final_type, s);
|
||||||
|
Loading…
Reference in New Issue
Block a user