mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-04 06:44:40 +08:00
RMS norm without scaling (#1915)
This commit is contained in:

committed by
GitHub

parent
5d68082881
commit
5e6c130d93
@@ -298,6 +298,9 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
rx = rms_norm(x, weight, eps)
|
||||
rx_fast = mx.fast.rms_norm(x, weight, eps)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||
rx = rms_norm(x, mx.ones_like(weight), eps)
|
||||
rx_fast = mx.fast.rms_norm(x, None, eps)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||
|
||||
for eps in epss:
|
||||
dtype, _, dims = defaults
|
||||
@@ -306,6 +309,9 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
rx = rms_norm(x, weight, eps)
|
||||
rx_fast = mx.fast.rms_norm(x, weight, eps)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||
rx = rms_norm(x, mx.ones_like(weight), eps)
|
||||
rx_fast = mx.fast.rms_norm(x, None, eps)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||
|
||||
for dims in dimss:
|
||||
dtype, eps, _ = defaults
|
||||
@@ -314,6 +320,9 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
rx = rms_norm(x, weight, eps)
|
||||
rx_fast = mx.fast.rms_norm(x, weight, eps)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||
rx = rms_norm(x, mx.ones_like(weight), eps)
|
||||
rx_fast = mx.fast.rms_norm(x, None, eps)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||
|
||||
# Test > 4096
|
||||
dims, dtype, eps = 4099, mx.float32, 1e-5
|
||||
@@ -333,6 +342,8 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
eps = 1e-5
|
||||
f1 = lambda x, w, y: (rms_norm(x, w, eps) * y).sum()
|
||||
f2 = lambda x, w, y: (mx.fast.rms_norm(x, w, eps) * y).sum()
|
||||
f3 = lambda x, y: (rms_norm(x, mx.ones((x.shape[-1],)), eps) * y).sum()
|
||||
f4 = lambda x, y: (mx.fast.rms_norm(x, None, eps) * y).sum()
|
||||
|
||||
x = mx.random.uniform(shape=(8, 100, D))
|
||||
w = mx.random.uniform(shape=(D,))
|
||||
@@ -341,6 +352,9 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
gx2, gw2 = mx.grad(f2, argnums=(0, 1))(x, w, y)
|
||||
self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)
|
||||
self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)
|
||||
gx1 = mx.grad(f3, argnums=(0,))(x, y)
|
||||
gx2 = mx.grad(f4, argnums=(0,))(x, y)
|
||||
self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)
|
||||
|
||||
D = 8192
|
||||
x = mx.random.uniform(shape=(2, 2, D))
|
||||
@@ -350,6 +364,9 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
gx2, gw2 = mx.grad(f2, argnums=(0, 1))(x, w, y)
|
||||
self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)
|
||||
self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)
|
||||
gx1 = mx.grad(f3, argnums=(0,))(x, y)
|
||||
gx2 = mx.grad(f4, argnums=(0,))(x, y)
|
||||
self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)
|
||||
|
||||
def gf(f):
|
||||
def inner(x, w, y):
|
||||
|
Reference in New Issue
Block a user