RMS norm without scaling (#1915)

Angelos Katharopoulos
2025-02-28 20:26:57 -08:00
committed by GitHub
parent 5d68082881
commit 5e6c130d93
9 changed files with 220 additions and 101 deletions


@@ -25,12 +25,12 @@ void init_fast(nb::module_& parent_module) {
       "rms_norm",
       &mx::fast::rms_norm,
       "x"_a,
-      "weight"_a,
+      "weight"_a.none(),
       "eps"_a,
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def rms_norm(x: array, weight: array, eps: float, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def rms_norm(x: array, weight: Optional[array], eps: float, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Root Mean Square normalization (RMS norm).
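
In nanobind, tagging an argument with `.none()` lets the binding accept Python `None` for that parameter; without it, passing `None` is rejected with a `TypeError` before the C++ function runs. A minimal sketch of what the relaxed signature allows, assuming an `mlx` build that includes this change:

```python
import mlx.core as mx

x = mx.random.uniform(shape=(2, 16))

# Previously rejected at the binding layer; now dispatches with weight=None.
out = mx.fast.rms_norm(x, None, 1e-5)
print(out.shape)  # (2, 16)
```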
@@ -38,9 +38,9 @@ void init_fast(nb::module_& parent_module) {
         Args:
             x (array): Input array.
-            weight (array): A multiplicative weight to scale the result by.
+            weight (array, optional): A multiplicative weight to scale the result by.
               The ``weight`` should be one-dimensional with the same size
-              as the last axis of ``x``.
+              as the last axis of ``x``. If set to ``None`` then no scaling happens.
             eps (float): A small additive constant for numerical stability.

         Returns:
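
Per the updated docstring, `weight=None` skips the scaling step, which should be numerically equivalent to passing an all-ones weight; that equivalence is exactly what the new tests below assert. A small sketch:

```python
import mlx.core as mx

x = mx.random.uniform(shape=(4, 512))
eps = 1e-5

y_none = mx.fast.rms_norm(x, None, eps)
y_ones = mx.fast.rms_norm(x, mx.ones((x.shape[-1],)), eps)

# Skipping the weight matches scaling by ones, up to float tolerance.
assert mx.abs(y_none - y_ones).max() < 1e-5
```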


@@ -298,6 +298,9 @@ class TestFast(mlx_tests.MLXTestCase):
             rx = rms_norm(x, weight, eps)
             rx_fast = mx.fast.rms_norm(x, weight, eps)
             self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+            rx = rms_norm(x, mx.ones_like(weight), eps)
+            rx_fast = mx.fast.rms_norm(x, None, eps)
+            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])

         for eps in epss:
             dtype, _, dims = defaults
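
The `rms_norm` these tests compare against is a pure-Python reference defined elsewhere in the test file and not shown in this hunk. A minimal sketch of such a reference, assuming the standard formulation `weight * x / sqrt(mean(x^2) + eps)` with the reduction done in float32 (the details here are reconstructed, not taken from the diff):

```python
import mlx.core as mx

def rms_norm(x, weight, eps):
    # Hypothetical reference: normalize by the root mean square of the
    # last axis, computed in float32, then scale by the weight.
    t = x.dtype
    x = x.astype(mx.float32)
    x = x * mx.rsqrt(x.square().mean(axis=-1, keepdims=True) + eps)
    return weight * x.astype(t)
```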
@@ -306,6 +309,9 @@ class TestFast(mlx_tests.MLXTestCase):
             rx = rms_norm(x, weight, eps)
             rx_fast = mx.fast.rms_norm(x, weight, eps)
             self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+            rx = rms_norm(x, mx.ones_like(weight), eps)
+            rx_fast = mx.fast.rms_norm(x, None, eps)
+            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])

         for dims in dimss:
             dtype, eps, _ = defaults
@@ -314,6 +320,9 @@ class TestFast(mlx_tests.MLXTestCase):
             rx = rms_norm(x, weight, eps)
             rx_fast = mx.fast.rms_norm(x, weight, eps)
             self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+            rx = rms_norm(x, mx.ones_like(weight), eps)
+            rx_fast = mx.fast.rms_norm(x, None, eps)
+            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])

         # Test > 4096
         dims, dtype, eps = 4099, mx.float32, 1e-5
@@ -333,6 +342,8 @@ class TestFast(mlx_tests.MLXTestCase):
         eps = 1e-5
         f1 = lambda x, w, y: (rms_norm(x, w, eps) * y).sum()
         f2 = lambda x, w, y: (mx.fast.rms_norm(x, w, eps) * y).sum()
+        f3 = lambda x, y: (rms_norm(x, mx.ones((x.shape[-1],)), eps) * y).sum()
+        f4 = lambda x, y: (mx.fast.rms_norm(x, None, eps) * y).sum()

         x = mx.random.uniform(shape=(8, 100, D))
         w = mx.random.uniform(shape=(D,))
@@ -341,6 +352,9 @@ class TestFast(mlx_tests.MLXTestCase):
         gx2, gw2 = mx.grad(f2, argnums=(0, 1))(x, w, y)
         self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)
         self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)
+        gx1 = mx.grad(f3, argnums=(0,))(x, y)
+        gx2 = mx.grad(f4, argnums=(0,))(x, y)
+        self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)

         D = 8192
         x = mx.random.uniform(shape=(2, 2, D))
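
With `weight=None` there is no weight to differentiate, so the new checks compare only the input gradient of `f3` (reference with an all-ones weight) against `f4` (fast path with `None`). The same pattern outside the test harness, as a minimal sketch with made-up shapes:

```python
import mlx.core as mx

eps = 1e-5
f = lambda x, y: (mx.fast.rms_norm(x, None, eps) * y).sum()

x = mx.random.uniform(shape=(2, 8))
y = mx.random.uniform(shape=(2, 8))

# Only d/dx exists for the unscaled variant; there is no weight gradient.
gx = mx.grad(f)(x, y)
print(gx.shape)  # (2, 8)
```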
@@ -350,6 +364,9 @@ class TestFast(mlx_tests.MLXTestCase):
         gx2, gw2 = mx.grad(f2, argnums=(0, 1))(x, w, y)
         self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)
         self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)
+        gx1 = mx.grad(f3, argnums=(0,))(x, y)
+        gx2 = mx.grad(f4, argnums=(0,))(x, y)
+        self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)

         def gf(f):
             def inner(x, w, y):