promote mask when needed (#1998)

This commit is contained in:
Awni Hannun
2025-03-23 19:58:28 -07:00
committed by GitHub
parent f018e248cd
commit a84cc0123f
2 changed files with 18 additions and 0 deletions

View File

@@ -750,6 +750,8 @@ array scaled_dot_product_attention(
msg << "[scaled_dot_product_attention] Mask type must promote to output type. "
<< final_type << ".";
throw std::invalid_argument(msg.str());
} else if (!has_bool_mask) {
mask_arr = astype(mask_arr, final_type, stream);
}
// Broadcast mask
auto mask_shape = queries.shape();