Adds mx.fast.layer_norm (#870)

This commit is contained in:
Angelos Katharopoulos
2024-03-21 13:55:51 -07:00
committed by GitHub
parent 105d236889
commit 2225374060
11 changed files with 600 additions and 8 deletions

View File

@@ -166,6 +166,105 @@ class TestFast(mlx_tests.MLXTestCase):
rx_fast = mx.fast.rms_norm(x, weight, eps)
self.assertLess(mx.abs(rx - rx_fast).max(), 1e-6)
def test_layer_norm(self):
    """Compare ``mx.fast.layer_norm`` with a reference built from basic ops.

    Each configuration draws random ``x``, ``weight`` and ``bias`` and
    checks all four combinations of weight/bias being present or ``None``.
    The maximum absolute difference must stay below a per-dtype tolerance.
    """

    def layer_norm(x, weight, bias, eps):
        # Reference: normalize over the last axis in float32, cast back to
        # the input dtype, then apply the optional affine terms.
        ot = x.dtype
        x = x.astype(mx.float32)
        mean = x.mean(axis=-1, keepdims=True)
        var = x.var(axis=-1, keepdims=True)
        x = (x - mean) * mx.rsqrt(var + eps)
        x = x.astype(ot)
        if weight is not None:
            x = x * weight
        if bias is not None:
            x = x + bias
        return x

    # Per dtype absolute tolerance
    tolerances = {mx.float32: 2e-6, mx.float16: 2e-3, mx.bfloat16: 2e-2}

    def check_all_variants(shape, dtype, eps):
        # Draw order (x, weight, bias) is kept fixed so the sequence of
        # random values consumed is the same for every configuration.
        dims = shape[-1]
        x = mx.random.uniform(shape=shape).astype(dtype)
        weight = mx.random.uniform(shape=(dims,)).astype(dtype)
        bias = mx.random.uniform(shape=(dims,)).astype(dtype)

        variants = (
            (weight, bias),
            (weight, None),
            (None, bias),
            (None, None),
        )
        for w, b in variants:
            rx = layer_norm(x, w, b, eps)
            rx_fast = mx.fast.layer_norm(x, w, b, eps)
            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])

    dtypes = [mx.float32, mx.float16, mx.bfloat16]
    epss = [1e-3, 1e-5]
    dimss = [31, 32, 33]
    default_dtype, default_eps, default_dims = (mx.float32, 1e-5, 32)

    # Vary one parameter at a time while the other two stay at their
    # defaults; inputs here are 2-row matrices.
    cases = [((2, default_dims), dtype, default_eps) for dtype in dtypes]
    cases += [((2, default_dims), default_dtype, eps) for eps in epss]
    cases += [((2, dims), default_dtype, default_eps) for dims in dimss]

    # Test > 4096 (1-D input whose normalized axis exceeds 4096 elements)
    cases.append(((4099,), mx.float32, 1e-5))

    for shape, dtype, eps in cases:
        check_all_variants(shape, dtype, eps)
def test_fast_transforms(self):
x = mx.random.uniform(shape=(2, 2, 8))