Implement vjps for some primitives in the fast namespace (#883)

* Implement rope vjp in terms of rope
* RMSNormVJP primitive and kernel
* Add LayerNormVJP primitive and kernel
Author: Angelos Katharopoulos
Date: 2024-03-26 16:35:34 -07:00 (committed by GitHub)
Commit: 29221fa238 (parent: a789685c63)
14 changed files with 1383 additions and 110 deletions
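
The rope vjp in the first bullet relies on the fact that rope applies a pure rotation to each (even, odd) feature pair, and the transpose of a rotation by theta is a rotation by -theta. The pullback is therefore the same rope-style rotation applied to the cotangent with the angle negated, which is why it can be expressed in terms of rope itself. Below is a minimal numerical sketch of that identity; rope_ref and its inverse flag are illustrative helpers, not the mx.fast API.

import math

import mlx.core as mx


def rope_ref(x, dims, base=10000.0, scale=1.0, offset=0, inverse=False):
    # Traditional-layout rotary embedding on the first `dims` features.
    # `inverse=True` rotates by -theta instead of +theta (illustrative flag only).
    N = x.shape[1] + offset
    half = dims // 2
    positions = mx.arange(offset, N) * scale
    freqs = mx.exp(-mx.arange(0.0, half) * (math.log(base) / half))
    theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
    sign = -1.0 if inverse else 1.0
    costheta, sintheta = mx.cos(theta), sign * mx.sin(theta)
    x1, x2 = x[..., :dims:2], x[..., 1:dims:2]
    rx1 = x1 * costheta - x2 * sintheta
    rx2 = x1 * sintheta + x2 * costheta
    rx = mx.concatenate([rx1[..., None], rx2[..., None]], axis=-1)
    return mx.reshape(rx, x.shape)


# The gradient of sum(rope(x) * y) w.r.t. x is the inverse rotation applied to y.
x = mx.random.uniform(shape=(2, 16, 32))
y = mx.random.uniform(shape=(2, 16, 32))
g = mx.grad(lambda x: (rope_ref(x, 32) * y).sum())(x)
print(mx.abs(g - rope_ref(y, 32, inverse=True)).max())  # ~0

The same argument covers the non-traditional layout, and features beyond dims are passed through unchanged, so their gradient is just the cotangent.
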


@@ -16,11 +16,14 @@ def rope_orig(x, dims, traditional, base, scale, offset):
     theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
     costheta, sintheta = mx.cos(theta), mx.sin(theta)
     if traditional:
-        x1 = x[..., ::2]
-        x2 = x[..., 1::2]
+        x1 = x[..., :dims:2]
+        x2 = x[..., 1:dims:2]
         rx1 = x1 * costheta - x2 * sintheta
         rx2 = x1 * sintheta + x2 * costheta
         rx = mx.concatenate([rx1[..., None], rx2[..., None]], axis=-1)
+        if dims < x.shape[-1]:
+            rx = mx.reshape(rx, (*x.shape[:-1], dims))
+            rx = mx.concatenate([rx, x[..., dims:]], axis=-1)
         return mx.reshape(rx, x.shape)
     else:
         x1 = x[..., : dims // 2]
@@ -34,6 +37,26 @@ def rope_orig(x, dims, traditional, base, scale, offset):
     return rx


+def rms_norm(x, weight, eps):
+    x = x.astype(mx.float32)
+    x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
+    return weight * x.astype(weight.dtype)
+
+
+def layer_norm(x, weight, bias, eps):
+    ot = x.dtype
+    x = x.astype(mx.float32)
+    mean = x.mean(axis=-1, keepdims=True)
+    var = x.var(axis=-1, keepdims=True)
+    x = (x - mean) * mx.rsqrt(var + eps)
+    x = x.astype(ot)
+    if weight is not None:
+        x = x * weight
+    if bias is not None:
+        x = x + bias
+    return x
+
+
 class TestFast(mlx_tests.MLXTestCase):
     def test_rope(self):
         T = 4
@@ -115,12 +138,34 @@ class TestFast(mlx_tests.MLXTestCase):
             )
             self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])

-    def test_rms_norm(self):
-        def rms_norm(x, weight, eps):
-            x = x.astype(mx.float32)
-            x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
-            return weight * x.astype(weight.dtype)
-
+    def test_rope_grad(self):
+        D = 32
+        defaults = (D, 10000.0, 1.0, 0, False)
+        for dims in (D, D // 2):
+            for traditional in (True, False):
+                _, base, scale, offset, _ = defaults
+                f1 = lambda x, y: (
+                    rope_orig(x, dims, traditional, base, scale, offset) * y
+                ).sum()
+                f2 = lambda x, y: (
+                    mx.fast.rope(
+                        x,
+                        dims,
+                        traditional=traditional,
+                        base=base,
+                        scale=scale,
+                        offset=offset,
+                    )
+                    * y
+                ).sum()
+
+                x = mx.random.uniform(shape=(2, 100, D))
+                y = mx.random.uniform(shape=(2, 100, D))
+                g1 = mx.grad(f1)(x, y)
+                g2 = mx.grad(f2)(x, y)
+                self.assertLess(mx.abs(g1 - g2).max(), 1e-5)
+
+    def test_rms_norm(self):
         # Per dtype absolute tolerance
         tolerances = {mx.float32: 1e-6, mx.float16: 1e-3, mx.bfloat16: 1e-2}
@@ -166,20 +211,42 @@ class TestFast(mlx_tests.MLXTestCase):
         rx_fast = mx.fast.rms_norm(x, weight, eps)
         self.assertLess(mx.abs(rx - rx_fast).max(), 1e-6)

-    def test_layer_norm(self):
-        def layer_norm(x, weight, bias, eps):
-            ot = x.dtype
-            x = x.astype(mx.float32)
-            mean = x.mean(axis=-1, keepdims=True)
-            var = x.var(axis=-1, keepdims=True)
-            x = (x - mean) * mx.rsqrt(var + eps)
-            x = x.astype(ot)
-            if weight is not None:
-                x = x * weight
-            if bias is not None:
-                x = x + bias
-            return x
-
+    def test_rms_norm_grad(self):
+        D = 32
+        eps = 1e-5
+        f1 = lambda x, w, y: (rms_norm(x, w, eps) * y).sum()
+        f2 = lambda x, w, y: (mx.fast.rms_norm(x, w, eps) * y).sum()
+
+        x = mx.random.uniform(shape=(8, 100, D))
+        w = mx.random.uniform(shape=(D,))
+        y = mx.random.uniform(shape=(8, 100, D))
+        gx1, gw1 = mx.grad(f1, argnums=(0, 1))(x, w, y)
+        gx2, gw2 = mx.grad(f2, argnums=(0, 1))(x, w, y)
+        self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)
+        self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)
+
+        D = 8192
+        x = mx.random.uniform(shape=(2, 2, D))
+        w = mx.random.uniform(shape=(D,))
+        y = mx.random.uniform(shape=(2, 2, D))
+        gx1, gw1 = mx.grad(f1, argnums=(0, 1))(x, w, y)
+        gx2, gw2 = mx.grad(f2, argnums=(0, 1))(x, w, y)
+        self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)
+        self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)
+
+        def gf(f):
+            def inner(x, w, y):
+                gx, gw = mx.grad(f, argnums=(0, 1))(x, w, y)
+                return (gx + gw).sum()
+
+            return inner
+
+        gx1, gw1 = mx.grad(gf(f1), argnums=(0, 1))(x, w, y)
+        gx2, gw2 = mx.grad(gf(f2), argnums=(0, 1))(x, w, y)
+        self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)
+        self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)
+
+    def test_layer_norm(self):
         # Per dtype absolute tolerance
         tolerances = {mx.float32: 3e-6, mx.float16: 3e-3, mx.bfloat16: 3e-2}
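
For reference, the gradients that the new RMSNormVJP primitive has to produce (and that test_rms_norm_grad above checks through mx.grad) can be written in a handful of ops. The sketch below is only the math, with an illustrative name (rms_norm_vjp_ref); it is not the fused kernel added in this commit.

import mlx.core as mx


def rms_norm_vjp_ref(x, w, cotan, eps):
    # With n = 1 / sqrt(mean(x^2) + eps) and xhat = x * n:
    #   dL/dx = n * (w * cotan - xhat * mean(w * cotan * xhat, axis=-1))
    #   dL/dw = sum over all leading axes of cotan * xhat
    n = mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
    xhat = x * n
    wc = w * cotan
    gx = n * (wc - xhat * (wc * xhat).mean(-1, keepdims=True))
    gw = mx.sum(mx.reshape(cotan * xhat, (-1, x.shape[-1])), axis=0)
    return gx, gw


x = mx.random.uniform(shape=(4, 64))
w = mx.random.uniform(shape=(64,))
cotan = mx.random.uniform(shape=(4, 64))
gx, gw = mx.grad(
    lambda x, w: (mx.fast.rms_norm(x, w, 1e-5) * cotan).sum(), argnums=(0, 1)
)(x, w)
gx_ref, gw_ref = rms_norm_vjp_ref(x, w, cotan, 1e-5)
print(mx.abs(gx - gx_ref).max(), mx.abs(gw - gw_ref).max())  # both tiny

Note that the test above exercises both a small (D = 32) and a large (D = 8192) last axis, presumably to hit different reduction paths in the kernel.
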
@@ -265,6 +332,49 @@ class TestFast(mlx_tests.MLXTestCase):
             rx_fast = mx.fast.layer_norm(x, None, None, eps)
             self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])

+    def test_layer_norm_grad(self):
+        D = 32
+        eps = 1e-5
+        f1 = lambda x, w, b, y: (layer_norm(x, w, b, eps) * y).sum()
+        f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, eps) * y).sum()
+
+        x = mx.random.uniform(shape=(8, 100, D))
+        w = mx.random.uniform(shape=(D,))
+        b = mx.random.uniform(shape=(D,))
+        y = mx.random.uniform(shape=(8, 100, D))
+        gx1, gw1, gb1 = mx.grad(f1, argnums=(0, 1, 2))(x, w, b, y)
+        gx2, gw2, gb2 = mx.grad(f2, argnums=(0, 1, 2))(x, w, b, y)
+        self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)
+        self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)
+        self.assertLess(mx.abs(gb1 - gb2).max() / mx.abs(gb1).mean(), 1e-5)
+
+        D = 8192
+        x = mx.random.uniform(shape=(8, 100, D))
+        w = mx.random.uniform(shape=(D,))
+        b = mx.random.uniform(shape=(D,))
+        y = mx.random.uniform(shape=(8, 100, D))
+        gx1, gw1, gb1 = mx.grad(f1, argnums=(0, 1, 2))(x, w, b, y)
+        gx2, gw2, gb2 = mx.grad(f2, argnums=(0, 1, 2))(x, w, b, y)
+        self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)
+        self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)
+        self.assertLess(mx.abs(gb1 - gb2).max() / mx.abs(gb1).mean(), 1e-5)
+
+        def gf(f):
+            def inner(x, w, b, y):
+                gx, gw, gb = mx.grad(f, argnums=(0, 1, 2))(x, w, b, y)
+                return ((gx + gw + gb) * y).sum()
+
+            return inner
+
+        gx1, gw1, gb1 = mx.grad(gf(f1), argnums=(0, 1, 2))(x, w, b, y)
+        gx2, gw2, gb2 = mx.grad(gf(f2), argnums=(0, 1, 2))(x, w, b, y)
+        self.assertLess(mx.abs(gx1 - gx2).max() / mx.abs(gx1).mean(), 1e-5)
+        self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)
+        self.assertLess(mx.abs(gb1).max(), 1e-9)
+        self.assertLess(mx.abs(gb2).max(), 1e-9)
+
     def test_fast_transforms(self):
         x = mx.random.uniform(shape=(2, 2, 8))
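
Likewise, the LayerNormVJP primitive exercised by test_layer_norm_grad has to produce the standard layer norm gradients. A small sketch of those formulas follows (layer_norm_vjp_ref is an illustrative name, not the kernel itself), checked against mx.grad of mx.fast.layer_norm.

import mlx.core as mx


def layer_norm_vjp_ref(x, w, b, cotan, eps):
    # With n = 1 / sqrt(var(x) + eps), xhat = (x - mean(x)) * n and wc = w * cotan:
    #   dL/dx = n * (wc - mean(wc, -1) - xhat * mean(wc * xhat, -1))
    #   dL/dw = sum over all leading axes of cotan * xhat
    #   dL/db = sum over all leading axes of cotan
    mean = x.mean(axis=-1, keepdims=True)
    n = mx.rsqrt(x.var(axis=-1, keepdims=True) + eps)
    xhat = (x - mean) * n
    wc = w * cotan
    gx = n * (wc - wc.mean(-1, keepdims=True) - xhat * (wc * xhat).mean(-1, keepdims=True))
    gw = mx.sum(mx.reshape(cotan * xhat, (-1, x.shape[-1])), axis=0)
    gb = mx.sum(mx.reshape(cotan, (-1, x.shape[-1])), axis=0)
    return gx, gw, gb


x = mx.random.uniform(shape=(4, 64))
w = mx.random.uniform(shape=(64,))
b = mx.random.uniform(shape=(64,))
cotan = mx.random.uniform(shape=(4, 64))
gx, gw, gb = mx.grad(
    lambda x, w, b: (mx.fast.layer_norm(x, w, b, 1e-5) * cotan).sum(), argnums=(0, 1, 2)
)(x, w, b)
gx_ref, gw_ref, gb_ref = layer_norm_vjp_ref(x, w, b, cotan, 1e-5)
print(mx.abs(gx - gx_ref).max(), mx.abs(gw - gw_ref).max(), mx.abs(gb - gb_ref).max())  # all tiny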