diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 731a10bad..02d0398bb 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -69,6 +69,14 @@ array rms_norm( << " dimensions."; throw std::invalid_argument(msg.str()); } + if (weight.size() != x.shape(-1)) { + std::ostringstream msg; + msg << "[rms_norm] weight must have the same size as the last dimension of" + " x but has " + << weight.size() << " elements."; + throw std::invalid_argument(msg.str()); + } + auto out_type = result_type(x, weight); if (!issubdtype(out_type, floating)) { std::ostringstream msg; diff --git a/python/tests/test_fast.py b/python/tests/test_fast.py index f989783a2..d27bdddfb 100644 --- a/python/tests/test_fast.py +++ b/python/tests/test_fast.py @@ -308,6 +308,11 @@ class TestFast(mlx_tests.MLXTestCase): rx_fast = mx.fast.rms_norm(x, weight, eps) self.assertLess(mx.abs(rx - rx_fast).max(), 1e-6) + # Wrong size w raises + with self.assertRaises(ValueError): + x = mx.random.uniform(shape=(1, 5)) + mx.fast.rms_norm(x, mx.ones((4,)), 1e-5) + def test_rms_norm_grad(self): D = 32 eps = 1e-5