mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
error in rms for wrong size (#1562)
This commit is contained in:
@@ -69,6 +69,14 @@ array rms_norm(
|
||||
<< " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (weight.size() != x.shape(-1)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[rms_norm] weight must have the same size as the last dimension of"
|
||||
" x but has "
|
||||
<< weight.size() << " elements.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
auto out_type = result_type(x, weight);
|
||||
if (!issubdtype(out_type, floating)) {
|
||||
std::ostringstream msg;
|
||||
|
||||
Reference in New Issue
Block a user