mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
error in rms for wrong size (#1562)
This commit is contained in:
parent
f1951d6cce
commit
76f275b4df
@ -69,6 +69,14 @@ array rms_norm(
|
|||||||
<< " dimensions.";
|
<< " dimensions.";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
if (weight.size() != x.shape(-1)) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[rms_norm] weight must have the same size as the last dimension of"
|
||||||
|
" x but has "
|
||||||
|
<< weight.size() << " elements.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
auto out_type = result_type(x, weight);
|
auto out_type = result_type(x, weight);
|
||||||
if (!issubdtype(out_type, floating)) {
|
if (!issubdtype(out_type, floating)) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
|
@ -308,6 +308,11 @@ class TestFast(mlx_tests.MLXTestCase):
|
|||||||
rx_fast = mx.fast.rms_norm(x, weight, eps)
|
rx_fast = mx.fast.rms_norm(x, weight, eps)
|
||||||
self.assertLess(mx.abs(rx - rx_fast).max(), 1e-6)
|
self.assertLess(mx.abs(rx - rx_fast).max(), 1e-6)
|
||||||
|
|
||||||
|
# Wrong size w raises
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
x = mx.random.uniform(shape=(1, 5))
|
||||||
|
mx.fast.rms_norm(x, mx.ones((4,)), 1e-5)
|
||||||
|
|
||||||
def test_rms_norm_grad(self):
|
def test_rms_norm_grad(self):
|
||||||
D = 32
|
D = 32
|
||||||
eps = 1e-5
|
eps = 1e-5
|
||||||
|
Loading…
Reference in New Issue
Block a user