From 941cfe23d7b7cc549506782f3b65621163e661ae Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Thu, 4 Dec 2025 11:21:05 -0800
Subject: [PATCH] Layer norm throws on dimension mismatch (#2870)

---
 mlx/fast.cpp              | 38 ++++++++++++++++++++++++++++----------
 python/tests/test_fast.py | 19 ++++++++++++++-----
 2 files changed, 42 insertions(+), 15 deletions(-)

diff --git a/mlx/fast.cpp b/mlx/fast.cpp
index 1ad42d0cf..97a3a5f6a 100644
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -202,17 +202,35 @@ array layer_norm(
            "0 dimensions.";
     throw std::invalid_argument(msg.str());
   }
-  if (has_weight && (*weight).ndim() != 1) {
-    std::ostringstream msg;
-    msg << "[layer_norm] weight must have 1 dimension but has "
-        << (*weight).ndim() << " dimensions.";
-    throw std::invalid_argument(msg.str());
+  if (has_weight) {
+    if ((*weight).ndim() != 1) {
+      std::ostringstream msg;
+      msg << "[layer_norm] weight must have 1 dimension but has "
+          << (*weight).ndim() << " dimensions.";
+      throw std::invalid_argument(msg.str());
+    }
+    if ((*weight).size() != x.shape(-1)) {
+      std::ostringstream msg;
+      msg << "[layer_norm] weight must have the same size as the last dimension of"
+             " x but has "
+          << (*weight).size() << " elements.";
+      throw std::invalid_argument(msg.str());
+    }
   }
-  if (has_bias && (*bias).ndim() != 1) {
-    std::ostringstream msg;
-    msg << "[layer_norm] bias must have 1 dimension but has " << (*bias).ndim()
-        << " dimensions.";
-    throw std::invalid_argument(msg.str());
+  if (has_bias) {
+    if ((*bias).ndim() != 1) {
+      std::ostringstream msg;
+      msg << "[layer_norm] bias must have 1 dimension but has "
+          << (*bias).ndim() << " dimensions.";
+      throw std::invalid_argument(msg.str());
+    }
+    if ((*bias).size() != x.shape(-1)) {
+      std::ostringstream msg;
+      msg << "[layer_norm] bias must have the same size as the last dimension of"
+             " x but has "
+          << (*bias).size() << " elements.";
+      throw std::invalid_argument(msg.str());
+    }
   }
 
   auto out_type = (has_weight)
diff --git a/python/tests/test_fast.py b/python/tests/test_fast.py
index f16aee05d..63c5b06ac 100644
--- a/python/tests/test_fast.py
+++ b/python/tests/test_fast.py
@@ -103,9 +103,6 @@ class TestFast(mlx_tests.MLXTestCase):
             dims, _, base, scale, offset, _ = defaults
             for dtype in dtypes:
                 x = mx.random.uniform(shape=(2, T, dims)).astype(dtype)
-                ry = rope_orig(
-                    x.astype(mx.float32), dims, traditional, base, scale, offset
-                )
                 rx = rope_orig(x, dims, traditional, base, scale, offset)
                 rx_fast = mx.fast.rope(
                     x,
@@ -116,9 +113,10 @@ class TestFast(mlx_tests.MLXTestCase):
                     offset=offset,
                 )
                 if dtype != mx.float32:
-                    self.assertLessEqual(
-                        mx.abs(ry - rx_fast).max(), mx.abs(ry - rx).max()
+                    ry = rope_orig(
+                        x.astype(mx.float32), dims, traditional, base, scale, offset
                     )
+                    self.assertLess(mx.abs(ry - rx_fast).max(), tolerances[dtype])
                 self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
 
             dims, dtype, base, scale, _, _ = defaults
@@ -455,6 +453,17 @@ class TestFast(mlx_tests.MLXTestCase):
         self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)
         self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)
 
+    def test_layer_norm_dim_check(self):
+        with self.assertRaises(ValueError):
+            weight = mx.ones((129,))
+            x = mx.random.randint(low=0, high=10, shape=(4, 128))
+            mx.fast.layer_norm(x, weight, None, 1e-3)
+
+        with self.assertRaises(ValueError):
+            bias = mx.ones((129,))
+            x = mx.random.randint(low=0, high=10, shape=(4, 128))
+            mx.fast.layer_norm(x, None, bias, 1e-3)
+
     def test_layer_norm(self):
         # Per dtype absolute tolerance
         tolerances = {mx.float32: 1e-5, mx.float16: 5e-3, mx.bfloat16: 5e-2}
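
Usage sketch (not part of the patch): with this change applied, a weight or bias
whose size does not match the last dimension of x makes mx.fast.layer_norm raise
a ValueError, mirroring the new test_layer_norm_dim_check test above. The values
below are illustrative only.

    # Sketch assuming the layer_norm(x, weight, bias, eps) call used in the tests.
    import mlx.core as mx

    x = mx.random.uniform(shape=(4, 128))  # last dimension of x is 128
    weight = mx.ones((129,))               # mismatched size: 129 != 128

    try:
        mx.fast.layer_norm(x, weight, None, 1e-3)
    except ValueError as err:
        # With the new size check, the mismatch is reported here instead of
        # being passed through unchecked.
        print(err)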