Layer norm throws on dimension mismatch (#2870)

Author: Awni Hannun
Date: 2025-12-04 11:21:05 -08:00
Committed by: GitHub
Parent: 9abb0b8123
Commit: 941cfe23d7

2 changed files with 42 additions and 15 deletions

@@ -202,18 +202,36 @@ array layer_norm(
"0 dimensions."; "0 dimensions.";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
if (has_weight && (*weight).ndim() != 1) { if (has_weight) {
if ((*weight).ndim() != 1) {
std::ostringstream msg; std::ostringstream msg;
msg << "[layer_norm] weight must have 1 dimension but has " msg << "[layer_norm] weight must have 1 dimension but has "
<< (*weight).ndim() << " dimensions."; << (*weight).ndim() << " dimensions.";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
if (has_bias && (*bias).ndim() != 1) { if ((*weight).size() != x.shape(-1)) {
std::ostringstream msg; std::ostringstream msg;
msg << "[layer_norm] bias must have 1 dimension but has " << (*bias).ndim() msg << "[layer_norm] weight must have the same size as the last dimension of"
<< " dimensions."; " x but has "
<< (*weight).size() << " elements.";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
}
if (has_bias) {
if ((*bias).ndim() != 1) {
std::ostringstream msg;
msg << "[layer_norm] bias must have 1 dimension but has "
<< (*bias).ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
if ((*bias).size() != x.shape(-1)) {
std::ostringstream msg;
msg << "[layer_norm] bias must have the same size as the last dimension of"
" x but has "
<< (*bias).size() << " elements.";
throw std::invalid_argument(msg.str());
}
}
auto out_type = (has_weight) auto out_type = (has_weight)
? ((has_bias) ? result_type(x, *weight, *bias) : result_type(x, *weight)) ? ((has_bias) ? result_type(x, *weight, *bias) : result_type(x, *weight))
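
From the Python API these checks surface as a ValueError, which is what the test added below asserts. A minimal sketch of the new behavior (the tensor shapes and eps value here are illustrative, not taken from the commit):

    import mlx.core as mx

    x = mx.random.uniform(shape=(4, 128))

    # Matching sizes: runs as before.
    y = mx.fast.layer_norm(x, mx.ones((128,)), mx.zeros((128,)), 1e-5)

    # A 129-element weight no longer slips through: the size check against
    # x.shape(-1) now raises up front.
    try:
        mx.fast.layer_norm(x, mx.ones((129,)), None, 1e-5)
    except ValueError as e:
        print(e)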

@@ -103,9 +103,6 @@ class TestFast(mlx_tests.MLXTestCase):
         dims, _, base, scale, offset, _ = defaults
         for dtype in dtypes:
             x = mx.random.uniform(shape=(2, T, dims)).astype(dtype)
-            ry = rope_orig(
-                x.astype(mx.float32), dims, traditional, base, scale, offset
-            )
             rx = rope_orig(x, dims, traditional, base, scale, offset)
             rx_fast = mx.fast.rope(
                 x,
@@ -116,9 +113,10 @@ class TestFast(mlx_tests.MLXTestCase):
                 offset=offset,
             )
             if dtype != mx.float32:
-                self.assertLessEqual(
-                    mx.abs(ry - rx_fast).max(), mx.abs(ry - rx).max()
-                )
+                ry = rope_orig(
+                    x.astype(mx.float32), dims, traditional, base, scale, offset
+                )
+                self.assertLess(mx.abs(ry - rx_fast).max(), tolerances[dtype])
             self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
 
         dims, dtype, base, scale, _, _ = defaults
@@ -455,6 +453,17 @@ class TestFast(mlx_tests.MLXTestCase):
         self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)
         self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)
 
+    def test_layer_norm_dim_check(self):
+        with self.assertRaises(ValueError):
+            weight = mx.ones((129,))
+            x = mx.random.randint(low=0, high=10, shape=(4, 128))
+            mx.fast.layer_norm(x, weight, None, 1e-3)
+
+        with self.assertRaises(ValueError):
+            bias = mx.ones((129,))
+            x = mx.random.randint(low=0, high=10, shape=(4, 128))
+            mx.fast.layer_norm(x, None, bias, 1e-3)
+
     def test_layer_norm(self):
         # Per dtype absolute tolerance
         tolerances = {mx.float32: 1e-5, mx.float16: 5e-3, mx.bfloat16: 5e-2}
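
For context, test_layer_norm checks mx.fast.layer_norm against an unfused float32 implementation under those per-dtype tolerances. The helper below is a sketch of that kind of reference, not the repository's own helper:

    import mlx.core as mx

    def layer_norm_ref(x, weight, bias, eps):
        # Unfused reference: normalize over the last axis in float32,
        # then apply the optional affine parameters.
        x32 = x.astype(mx.float32)
        mu = mx.mean(x32, axis=-1, keepdims=True)
        var = mx.var(x32, axis=-1, keepdims=True)
        y = (x32 - mu) * mx.rsqrt(var + eps)
        if weight is not None:
            y = y * weight
        if bias is not None:
            y = y + bias
        return y.astype(x.dtype)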