mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
error in rms for wrong size (#1562)
This commit is contained in:
@@ -308,6 +308,11 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
rx_fast = mx.fast.rms_norm(x, weight, eps)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), 1e-6)
|
||||
|
||||
# Wrong size w raises
|
||||
with self.assertRaises(ValueError):
|
||||
x = mx.random.uniform(shape=(1, 5))
|
||||
mx.fast.rms_norm(x, mx.ones((4,)), 1e-5)
|
||||
|
||||
def test_rms_norm_grad(self):
|
||||
D = 32
|
||||
eps = 1e-5
|
||||
|
Reference in New Issue
Block a user