mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-10 21:37:50 +08:00
Adds mx.fast.layer_norm (#870)
This commit is contained in:

committed by
GitHub

parent
105d236889
commit
2225374060
@@ -166,6 +166,105 @@ class TestFast(mlx_tests.MLXTestCase):
|
||||
rx_fast = mx.fast.rms_norm(x, weight, eps)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), 1e-6)
|
||||
|
||||
def test_layer_norm(self):
|
||||
def layer_norm(x, weight, bias, eps):
|
||||
ot = x.dtype
|
||||
x = x.astype(mx.float32)
|
||||
mean = x.mean(axis=-1, keepdims=True)
|
||||
var = x.var(axis=-1, keepdims=True)
|
||||
x = (x - mean) * mx.rsqrt(var + eps)
|
||||
x = x.astype(ot)
|
||||
if weight is not None:
|
||||
x = x * weight
|
||||
if bias is not None:
|
||||
x = x + bias
|
||||
return x
|
||||
|
||||
# Per dtype absolute tolerance
|
||||
tolerances = {mx.float32: 2e-6, mx.float16: 2e-3, mx.bfloat16: 2e-2}
|
||||
|
||||
dtypes = [mx.float32, mx.float16, mx.bfloat16]
|
||||
epss = [1e-3, 1e-5]
|
||||
dimss = [31, 32, 33]
|
||||
defaults = (mx.float32, 1e-5, 32)
|
||||
|
||||
for dtype in dtypes:
|
||||
_, eps, dims = defaults
|
||||
x = mx.random.uniform(
|
||||
shape=(
|
||||
2,
|
||||
dims,
|
||||
)
|
||||
).astype(dtype)
|
||||
weight = mx.random.uniform(shape=(dims,)).astype(dtype)
|
||||
bias = mx.random.uniform(shape=(dims,)).astype(dtype)
|
||||
rx = layer_norm(x, weight, bias, eps)
|
||||
rx_fast = mx.fast.layer_norm(x, weight, bias, eps)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||
rx = layer_norm(x, weight, None, eps)
|
||||
rx_fast = mx.fast.layer_norm(x, weight, None, eps)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||
rx = layer_norm(x, None, bias, eps)
|
||||
rx_fast = mx.fast.layer_norm(x, None, bias, eps)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||
rx = layer_norm(x, None, None, eps)
|
||||
rx_fast = mx.fast.layer_norm(x, None, None, eps)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||
|
||||
for eps in epss:
|
||||
dtype, _, dims = defaults
|
||||
x = mx.random.uniform(shape=(2, dims)).astype(dtype)
|
||||
weight = mx.random.uniform(shape=(dims,)).astype(dtype)
|
||||
bias = mx.random.uniform(shape=(dims,)).astype(dtype)
|
||||
rx = layer_norm(x, weight, bias, eps)
|
||||
rx_fast = mx.fast.layer_norm(x, weight, bias, eps)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||
rx = layer_norm(x, weight, None, eps)
|
||||
rx_fast = mx.fast.layer_norm(x, weight, None, eps)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||
rx = layer_norm(x, None, bias, eps)
|
||||
rx_fast = mx.fast.layer_norm(x, None, bias, eps)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||
rx = layer_norm(x, None, None, eps)
|
||||
rx_fast = mx.fast.layer_norm(x, None, None, eps)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||
|
||||
for dims in dimss:
|
||||
dtype, eps, _ = defaults
|
||||
x = mx.random.uniform(shape=(2, dims)).astype(dtype)
|
||||
weight = mx.random.uniform(shape=(dims,)).astype(dtype)
|
||||
bias = mx.random.uniform(shape=(dims,)).astype(dtype)
|
||||
rx = layer_norm(x, weight, bias, eps)
|
||||
rx_fast = mx.fast.layer_norm(x, weight, bias, eps)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||
rx = layer_norm(x, weight, None, eps)
|
||||
rx_fast = mx.fast.layer_norm(x, weight, None, eps)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||
rx = layer_norm(x, None, bias, eps)
|
||||
rx_fast = mx.fast.layer_norm(x, None, bias, eps)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||
rx = layer_norm(x, None, None, eps)
|
||||
rx_fast = mx.fast.layer_norm(x, None, None, eps)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||
|
||||
# Test > 4096
|
||||
dims, dtype, eps = 4099, mx.float32, 1e-5
|
||||
x = mx.random.uniform(shape=(dims,)).astype(dtype)
|
||||
weight = mx.random.uniform(shape=(dims,)).astype(dtype)
|
||||
bias = mx.random.uniform(shape=(dims,)).astype(dtype)
|
||||
rx = layer_norm(x, weight, bias, eps)
|
||||
rx_fast = mx.fast.layer_norm(x, weight, bias, eps)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||
rx = layer_norm(x, weight, None, eps)
|
||||
rx_fast = mx.fast.layer_norm(x, weight, None, eps)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||
rx = layer_norm(x, None, bias, eps)
|
||||
rx_fast = mx.fast.layer_norm(x, None, bias, eps)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||
rx = layer_norm(x, None, None, eps)
|
||||
rx_fast = mx.fast.layer_norm(x, None, None, eps)
|
||||
self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
|
||||
|
||||
def test_fast_transforms(self):
|
||||
x = mx.random.uniform(shape=(2, 2, 8))
|
||||
|
||||
|
Reference in New Issue
Block a user