check mask type (#1721)

This commit is contained in:
Alex Barron 2024-12-18 14:25:18 -08:00 committed by GitHub
parent 5548fcc96d
commit 49c34c4161
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -586,6 +586,13 @@ array scaled_dot_product_attention(
throw std::invalid_argument(msg.str());
}
if (mask && (*mask).dtype() != final_type) {
std::ostringstream msg;
msg << "[scaled_dot_product_attention] Mask should match output type. "
<< final_type << ".";
throw std::invalid_argument(msg.str());
}
auto q = astype(queries, final_type, s);
auto k = astype(keys, final_type, s);
auto v = astype(values, final_type, s);