error in rms for wrong size (#1562)

This commit is contained in:
Awni Hannun 2024-11-04 13:24:02 -08:00 committed by GitHub
parent f1951d6cce
commit 76f275b4df
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 13 additions and 0 deletions

View File

@ -69,6 +69,14 @@ array rms_norm(
<< " dimensions."; << " dimensions.";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
if (weight.size() != x.shape(-1)) {
std::ostringstream msg;
msg << "[rms_norm] weight must have the same size as the last dimension of"
" x but has "
<< weight.size() << " elements.";
throw std::invalid_argument(msg.str());
}
auto out_type = result_type(x, weight); auto out_type = result_type(x, weight);
if (!issubdtype(out_type, floating)) { if (!issubdtype(out_type, floating)) {
std::ostringstream msg; std::ostringstream msg;

View File

@ -308,6 +308,11 @@ class TestFast(mlx_tests.MLXTestCase):
rx_fast = mx.fast.rms_norm(x, weight, eps) rx_fast = mx.fast.rms_norm(x, weight, eps)
self.assertLess(mx.abs(rx - rx_fast).max(), 1e-6) self.assertLess(mx.abs(rx - rx_fast).max(), 1e-6)
# Wrong size w raises
with self.assertRaises(ValueError):
x = mx.random.uniform(shape=(1, 5))
mx.fast.rms_norm(x, mx.ones((4,)), 1e-5)
def test_rms_norm_grad(self): def test_rms_norm_grad(self):
D = 32 D = 32
eps = 1e-5 eps = 1e-5