Layer norm throws on dimension mismatch (#2870)
mlx/fast.cpp | 38 lines changed
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -202,17 +202,35 @@ array layer_norm(
            "0 dimensions.";
     throw std::invalid_argument(msg.str());
   }
-  if (has_weight && (*weight).ndim() != 1) {
-    std::ostringstream msg;
-    msg << "[layer_norm] weight must have 1 dimension but has "
-        << (*weight).ndim() << " dimensions.";
-    throw std::invalid_argument(msg.str());
+  if (has_weight) {
+    if ((*weight).ndim() != 1) {
+      std::ostringstream msg;
+      msg << "[layer_norm] weight must have 1 dimension but has "
+          << (*weight).ndim() << " dimensions.";
+      throw std::invalid_argument(msg.str());
+    }
+    if ((*weight).size() != x.shape(-1)) {
+      std::ostringstream msg;
+      msg << "[layer_norm] weight must have the same size as the last dimension of"
+             " x but has "
+          << (*weight).size() << " elements.";
+      throw std::invalid_argument(msg.str());
+    }
   }
-  if (has_bias && (*bias).ndim() != 1) {
-    std::ostringstream msg;
-    msg << "[layer_norm] bias must have 1 dimension but has " << (*bias).ndim()
-        << " dimensions.";
-    throw std::invalid_argument(msg.str());
+  if (has_bias) {
+    if ((*bias).ndim() != 1) {
+      std::ostringstream msg;
+      msg << "[layer_norm] bias must have 1 dimension but has "
+          << (*bias).ndim() << " dimensions.";
+      throw std::invalid_argument(msg.str());
+    }
+    if ((*bias).size() != x.shape(-1)) {
+      std::ostringstream msg;
+      msg << "[layer_norm] bias must have the same size as the last dimension of"
+             " x but has "
+          << (*bias).size() << " elements.";
+      throw std::invalid_argument(msg.str());
+    }
   }

   auto out_type = (has_weight)
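With this change, layer_norm validates not only that weight and bias are 1-D but also that their length matches the last dimension of x; previously a wrong-sized 1-D weight or bias passed the check. A minimal sketch of the new behavior from Python (the C++ std::invalid_argument surfaces as a ValueError, which the new test below relies on):

    import mlx.core as mx

    x = mx.zeros((4, 128))
    w = mx.ones((129,))  # 1-D, but 129 != x.shape[-1]

    try:
        mx.fast.layer_norm(x, w, None, 1e-5)
    except ValueError as e:
        # "[layer_norm] weight must have the same size as the last
        # dimension of x but has 129 elements."
        print(e)

The commit also touches the Python tests (the hunks below appear to come from python/tests/test_fast.py):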
--- a/python/tests/test_fast.py
+++ b/python/tests/test_fast.py
@@ -103,9 +103,6 @@ class TestFast(mlx_tests.MLXTestCase):
         dims, _, base, scale, offset, _ = defaults
         for dtype in dtypes:
             x = mx.random.uniform(shape=(2, T, dims)).astype(dtype)
-            ry = rope_orig(
-                x.astype(mx.float32), dims, traditional, base, scale, offset
-            )
             rx = rope_orig(x, dims, traditional, base, scale, offset)
             rx_fast = mx.fast.rope(
                 x,
@@ -116,9 +113,10 @@ class TestFast(mlx_tests.MLXTestCase):
                 offset=offset,
             )
             if dtype != mx.float32:
-                self.assertLessEqual(
-                    mx.abs(ry - rx_fast).max(), mx.abs(ry - rx).max()
-                )
+                ry = rope_orig(
+                    x.astype(mx.float32), dims, traditional, base, scale, offset
+                )
+                self.assertLess(mx.abs(ry - rx_fast).max(), tolerances[dtype])
             self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])

         dims, dtype, base, scale, _, _ = defaults
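The RoPE hunks above restructure the comparison: the float32 reference ry is only needed on the non-float32 path, so it is now computed inside that branch, and the old relative check against the unfused rope (assertLessEqual) becomes the same per-dtype absolute tolerance already used for rx. The resulting pattern, sketched with the test's own names:

    # fast kernel vs. a float32 reference, judged by an absolute per-dtype
    # tolerance rather than by the unfused low-precision run
    if dtype != mx.float32:
        ry = rope_orig(x.astype(mx.float32), dims, traditional, base, scale, offset)
        assert mx.abs(ry - rx_fast).max() < tolerances[dtype]
    assert mx.abs(rx - rx_fast).max() < tolerances[dtype]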
@@ -455,6 +453,17 @@ class TestFast(mlx_tests.MLXTestCase):
         self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)
         self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)

+    def test_layer_norm_dim_check(self):
+        with self.assertRaises(ValueError):
+            weight = mx.ones((129,))
+            x = mx.random.randint(low=0, high=10, shape=(4, 128))
+            mx.fast.layer_norm(x, weight, None, 1e-3)
+
+        with self.assertRaises(ValueError):
+            bias = mx.ones((129,))
+            x = mx.random.randint(low=0, high=10, shape=(4, 128))
+            mx.fast.layer_norm(x, None, bias, 1e-3)
+
     def test_layer_norm(self):
         # Per dtype absolute tolerance
         tolerances = {mx.float32: 1e-5, mx.float16: 5e-3, mx.bfloat16: 5e-2}
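For contrast with the two failing calls the new test exercises, a correct call looks like this (a sketch; shapes chosen to mirror the test):

    import mlx.core as mx

    x = mx.random.normal((4, 128))
    w = mx.ones((128,))   # matches x.shape[-1]
    b = mx.zeros((128,))
    y = mx.fast.layer_norm(x, w, b, 1e-5)  # runs without raising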