Implement vjps for some primitives in the fast namespace (#883)
* Implement rope vjp in terms of rope
* RMSNormVJP primitive and kernel
* Add LayerNormVJP primitive and kernel
parent a789685c63
commit 29221fa238
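Background for the first bullet: rope's vjp can be expressed in terms of rope itself, because the pullback of a rotation is the rotation by the negated angle. A minimal sketch of that identity, checked against autodiff with mx.vjp; rope_pair and the angle setup here are illustrative stand-ins, not the fast kernel:

import math

import mlx.core as mx


def rope_pair(x, theta, sign=1.0):
    # Rotate interleaved pairs (x[..., ::2], x[..., 1::2]) by sign * theta.
    x1, x2 = x[..., ::2], x[..., 1::2]
    c, s = mx.cos(theta), sign * mx.sin(theta)
    rx = mx.concatenate(
        [(x1 * c - x2 * s)[..., None], (x1 * s + x2 * c)[..., None]], axis=-1
    )
    return mx.reshape(rx, x.shape)


T, D = 5, 8
positions = mx.arange(T, dtype=mx.float32)
freqs = mx.exp(-mx.arange(0.0, D // 2) * (math.log(10000.0) / (D // 2)))
theta = positions[:, None] * freqs[None, :]

x = mx.random.uniform(shape=(T, D))
cotan = mx.random.uniform(shape=(T, D))
_, (vjp_auto,) = mx.vjp(lambda x: rope_pair(x, theta), [x], [cotan])
vjp_as_rope = rope_pair(cotan, theta, sign=-1.0)  # rope with the angle negated
assert mx.abs(vjp_auto - vjp_as_rope).max() < 1e-5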
@@ -16,11 +16,14 @@ def rope_orig(x, dims, traditional, base, scale, offset):
     theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
     costheta, sintheta = mx.cos(theta), mx.sin(theta)
     if traditional:
-        x1 = x[..., ::2]
-        x2 = x[..., 1::2]
+        x1 = x[..., :dims:2]
+        x2 = x[..., 1:dims:2]
         rx1 = x1 * costheta - x2 * sintheta
         rx2 = x1 * sintheta + x2 * costheta
         rx = mx.concatenate([rx1[..., None], rx2[..., None]], axis=-1)
+        if dims < x.shape[-1]:
+            rx = mx.reshape(rx, (*x.shape[:-1], dims))
+            rx = mx.concatenate([rx, x[..., dims:]], axis=-1)
         return mx.reshape(rx, x.shape)
     else:
         x1 = x[..., : dims // 2]
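This hunk makes the traditional branch of the reference handle partial rotation: with dims < x.shape[-1], only the first dims features are rotated and the tail passes through. A quick illustrative check, assuming rope_orig as defined in this file:

import mlx.core as mx

x = mx.random.uniform(shape=(1, 4, 8))
rx = rope_orig(x, dims=4, traditional=True, base=10000.0, scale=1.0, offset=0)
# The trailing features beyond `dims` are unchanged.
assert mx.array_equal(rx[..., 4:], x[..., 4:])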
@@ -34,6 +37,26 @@ def rope_orig(x, dims, traditional, base, scale, offset):
     return rx
 
 
+def rms_norm(x, weight, eps):
+    x = x.astype(mx.float32)
+    x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
+    return weight * x.astype(weight.dtype)
+
+
+def layer_norm(x, weight, bias, eps):
+    ot = x.dtype
+    x = x.astype(mx.float32)
+    mean = x.mean(axis=-1, keepdims=True)
+    var = x.var(axis=-1, keepdims=True)
+    x = (x - mean) * mx.rsqrt(var + eps)
+    x = x.astype(ot)
+    if weight is not None:
+        x = x * weight
+    if bias is not None:
+        x = x + bias
+    return x
+
+
 class TestFast(mlx_tests.MLXTestCase):
     def test_rope(self):
         T = 4
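For reference, the closed form that an RMSNormVJP kernel can fuse into one pass follows from the rms_norm definition above: with xhat = x * rsqrt(mean(x^2) + eps), the input gradient is rsqrt(...) * (w*g - xhat * mean(w*g*xhat)) and the weight gradient reduces g*xhat over the leading axes. A hedged sketch, checked against autodiff; rms_norm_vjp is an illustrative name, not an MLX internal:

import mlx.core as mx


def rms_norm_vjp(x, w, g, eps):
    # Pull the cotangent g back through y = w * x * rsqrt(mean(x^2) + eps).
    r = mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
    xhat = x * r
    wg = w * g
    gx = r * (wg - xhat * (wg * xhat).mean(-1, keepdims=True))
    gw = (g * xhat).sum(axis=tuple(range(x.ndim - 1)))
    return gx, gw


x = mx.random.uniform(shape=(4, 16))
w = mx.random.uniform(shape=(16,))
g = mx.random.uniform(shape=(4, 16))
eps = 1e-5
gx, gw = rms_norm_vjp(x, w, g, eps)
gx_ad, gw_ad = mx.grad(lambda x, w: (rms_norm(x, w, eps) * g).sum(), argnums=(0, 1))(x, w)
assert mx.abs(gx - gx_ad).max() < 1e-5
assert mx.abs(gw - gw_ad).max() < 1e-5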
@@ -115,12 +138,34 @@ class TestFast(mlx_tests.MLXTestCase):
             )
             self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
 
-    def test_rms_norm(self):
-        def rms_norm(x, weight, eps):
-            x = x.astype(mx.float32)
-            x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
-            return weight * x.astype(weight.dtype)
+    def test_rope_grad(self):
+        D = 32
+        defaults = (D, 10000.0, 1.0, 0, False)
+        for dims in (D, D // 2):
+            for traditional in (True, False):
+                _, base, scale, offset, _ = defaults
+                f1 = lambda x, y: (
+                    rope_orig(x, dims, traditional, base, scale, offset) * y
+                ).sum()
+                f2 = lambda x, y: (
+                    mx.fast.rope(
+                        x,
+                        dims,
+                        traditional=traditional,
+                        base=base,
+                        scale=scale,
+                        offset=offset,
+                    )
+                    * y
+                ).sum()
 
+                x = mx.random.uniform(shape=(2, 100, D))
+                y = mx.random.uniform(shape=(2, 100, D))
+                g1 = mx.grad(f1)(x, y)
+                g2 = mx.grad(f2)(x, y)
+                self.assertLess(mx.abs(g1 - g2).max(), 1e-5)
+
+    def test_rms_norm(self):
         # Per dtype absolute tolerance
         tolerances = {mx.float32: 1e-6, mx.float16: 1e-3, mx.bfloat16: 1e-2}
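With the vjp in place, mx.fast.rope composes with mx.grad the same way the Python reference does, e.g.:

import mlx.core as mx

x = mx.random.uniform(shape=(2, 100, 32))
loss = lambda x: mx.fast.rope(
    x, 32, traditional=False, base=10000.0, scale=1.0, offset=0
).sum()
dx = mx.grad(loss)(x)  # gradient flows through the fast kernel's vjp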
@@ -166,20 +211,42 @@ class TestFast(mlx_tests.MLXTestCase):
         rx_fast = mx.fast.rms_norm(x, weight, eps)
         self.assertLess(mx.abs(rx - rx_fast).max(), 1e-6)
 
-    def test_layer_norm(self):
-        def layer_norm(x, weight, bias, eps):
-            ot = x.dtype
-            x = x.astype(mx.float32)
-            mean = x.mean(axis=-1, keepdims=True)
-            var = x.var(axis=-1, keepdims=True)
-            x = (x - mean) * mx.rsqrt(var + eps)
-            x = x.astype(ot)
-            if weight is not None:
-                x = x * weight
-            if bias is not None:
-                x = x + bias
-            return x
+    def test_rms_norm_grad(self):
+        D = 32
+        eps = 1e-5
+        f1 = lambda x, w, y: (rms_norm(x, w, eps) * y).sum()
+        f2 = lambda x, w, y: (mx.fast.rms_norm(x, w, eps) * y).sum()
 
+        x = mx.random.uniform(shape=(8, 100, D))
+        w = mx.random.uniform(shape=(D,))
+        y = mx.random.uniform(shape=(8, 100, D))
+        gx1, gw1 = mx.grad(f1, argnums=(0, 1))(x, w, y)
+        gx2, gw2 = mx.grad(f2, argnums=(0, 1))(x, w, y)
+        self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)
+        self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)
+
+        D = 8192
+        x = mx.random.uniform(shape=(2, 2, D))
+        w = mx.random.uniform(shape=(D,))
+        y = mx.random.uniform(shape=(2, 2, D))
+        gx1, gw1 = mx.grad(f1, argnums=(0, 1))(x, w, y)
+        gx2, gw2 = mx.grad(f2, argnums=(0, 1))(x, w, y)
+        self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)
+        self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)
+
+        def gf(f):
+            def inner(x, w, y):
+                gx, gw = mx.grad(f, argnums=(0, 1))(x, w, y)
+                return (gx + gw).sum()
+
+            return inner
+
+        gx1, gw1 = mx.grad(gf(f1), argnums=(0, 1))(x, w, y)
+        gx2, gw2 = mx.grad(gf(f2), argnums=(0, 1))(x, w, y)
+        self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)
+        self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)
+
+    def test_layer_norm(self):
         # Per dtype absolute tolerance
         tolerances = {mx.float32: 3e-6, mx.float16: 3e-3, mx.bfloat16: 3e-2}
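The gf wrapper above differentiates a function of the first gradients, so these tests also exercise the vjp of the vjp. A minimal standalone version of the same pattern:

import mlx.core as mx

f = lambda x: x.square().sum()  # f'(x) = 2x, so the grad-of-grad is constant 2
g2 = mx.grad(lambda x: mx.grad(f)(x).sum())(mx.array([1.0, 2.0, 3.0]))
assert mx.abs(g2 - 2.0).max() < 1e-6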
@@ -265,6 +332,49 @@ class TestFast(mlx_tests.MLXTestCase):
         rx_fast = mx.fast.layer_norm(x, None, None, eps)
         self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
 
+    def test_layer_norm_grad(self):
+        D = 32
+        eps = 1e-5
+        f1 = lambda x, w, b, y: (layer_norm(x, w, b, eps) * y).sum()
+        f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, eps) * y).sum()
+
+        x = mx.random.uniform(shape=(8, 100, D))
+        w = mx.random.uniform(shape=(D,))
+        b = mx.random.uniform(shape=(D,))
+        y = mx.random.uniform(shape=(8, 100, D))
+
+        gx1, gw1, gb1 = mx.grad(f1, argnums=(0, 1, 2))(x, w, b, y)
+        gx2, gw2, gb2 = mx.grad(f2, argnums=(0, 1, 2))(x, w, b, y)
+        self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)
+        self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)
+        self.assertLess(mx.abs(gb1 - gb2).max() / mx.abs(gb1).mean(), 1e-5)
+
+        D = 8192
+        x = mx.random.uniform(shape=(8, 100, D))
+        w = mx.random.uniform(shape=(D,))
+        b = mx.random.uniform(shape=(D,))
+        y = mx.random.uniform(shape=(8, 100, D))
+
+        gx1, gw1, gb1 = mx.grad(f1, argnums=(0, 1, 2))(x, w, b, y)
+        gx2, gw2, gb2 = mx.grad(f2, argnums=(0, 1, 2))(x, w, b, y)
+        self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)
+        self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)
+        self.assertLess(mx.abs(gb1 - gb2).max() / mx.abs(gb1).mean(), 1e-5)
+
+        def gf(f):
+            def inner(x, w, b, y):
+                gx, gw, gb = mx.grad(f, argnums=(0, 1, 2))(x, w, b, y)
+                return ((gx + gw + gb) * y).sum()
+
+            return inner
+
+        gx1, gw1, gb1 = mx.grad(gf(f1), argnums=(0, 1, 2))(x, w, b, y)
+        gx2, gw2, gb2 = mx.grad(gf(f2), argnums=(0, 1, 2))(x, w, b, y)
+        self.assertLess(mx.abs(gx1 - gx2).max() / mx.abs(gx1).mean(), 1e-5)
+        self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)
+        self.assertLess(mx.abs(gb1).max(), 1e-9)
+        self.assertLess(mx.abs(gb2).max(), 1e-9)
+
     def test_fast_transforms(self):
         x = mx.random.uniform(shape=(2, 2, 8))
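Analogously to the RMS case, the closed form behind a LayerNormVJP kernel follows from the layer_norm reference added above: with xhat = (x - mean) * rsqrt(var + eps), the input gradient is rsqrt(...) * (w*g - mean(w*g) - xhat * mean(w*g*xhat)), while the weight and bias gradients reduce g*xhat and g over the leading axes. A hedged, autodiff-checked sketch; layer_norm_vjp is an illustrative name, not an MLX internal:

import mlx.core as mx


def layer_norm_vjp(x, w, b, g, eps):
    # Pull the cotangent g back through y = w * (x - mean) * rsqrt(var + eps) + b.
    mean = x.mean(axis=-1, keepdims=True)
    r = mx.rsqrt(x.var(axis=-1, keepdims=True) + eps)
    xhat = (x - mean) * r
    wg = w * g
    gx = r * (
        wg - wg.mean(-1, keepdims=True) - xhat * (wg * xhat).mean(-1, keepdims=True)
    )
    lead = tuple(range(x.ndim - 1))
    return gx, (g * xhat).sum(axis=lead), g.sum(axis=lead)


x = mx.random.uniform(shape=(4, 16))
w = mx.random.uniform(shape=(16,))
b = mx.random.uniform(shape=(16,))
g = mx.random.uniform(shape=(4, 16))
eps = 1e-5
grads = layer_norm_vjp(x, w, b, g, eps)
grads_ad = mx.grad(
    lambda x, w, b: (layer_norm(x, w, b, eps) * g).sum(), argnums=(0, 1, 2)
)(x, w, b)
for ours, ad in zip(grads, grads_ad):
    assert mx.abs(ours - ad).max() < 1e-4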