error in rms for wrong size (#1562)

This commit is contained in:
Awni Hannun
2024-11-04 13:24:02 -08:00
committed by GitHub
parent f1951d6cce
commit 76f275b4df
2 changed files with 13 additions and 0 deletions

View File

@@ -69,6 +69,14 @@ array rms_norm(
<< " dimensions.";
throw std::invalid_argument(msg.str());
}
if (weight.size() != x.shape(-1)) {
std::ostringstream msg;
msg << "[rms_norm] weight must have the same size as the last dimension of"
" x but has "
<< weight.size() << " elements.";
throw std::invalid_argument(msg.str());
}
auto out_type = result_type(x, weight);
if (!issubdtype(out_type, floating)) {
std::ostringstream msg;