Adds mx.fast.layer_norm (#870)

Angelos Katharopoulos
2024-03-21 13:55:51 -07:00
committed by GitHub
parent 105d236889
commit 2225374060
11 changed files with 600 additions and 8 deletions


@@ -85,13 +85,19 @@ class LayerNorm(Module):
         eps (float): A small additive constant for numerical stability
         affine (bool): If True learn an affine transform to apply after the
             normalization
+        bias (bool): If True include a translation to the affine
+            transformation. If set to False the transformation is not really affine
+            just scaling.
     """
 
-    def __init__(self, dims: int, eps: float = 1e-5, affine: bool = True):
+    def __init__(
+        self, dims: int, eps: float = 1e-5, affine: bool = True, bias: bool = True
+    ):
         super().__init__()
         if affine:
-            self.bias = mx.zeros((dims,))
             self.weight = mx.ones((dims,))
+            if bias:
+                self.bias = mx.zeros((dims,))
         self.eps = eps
         self.dims = dims
@@ -99,10 +105,9 @@ class LayerNorm(Module):
         return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"
 
     def __call__(self, x):
-        means = mx.mean(x, axis=-1, keepdims=True)
-        var = mx.var(x, axis=-1, keepdims=True)
-        x = (x - means) * mx.rsqrt(var + self.eps)
-        return (self.weight * x + self.bias) if "weight" in self else x
+        weight = self.weight if "weight" in self else None
+        bias = self.bias if "bias" in self else None
+        return mx.fast.layer_norm(x, weight, bias, self.eps)
 
 
 class RMSNorm(Module):

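For context, a minimal usage sketch of the updated module, based only on the signature shown in the diff above (the input shape is illustrative):

import mlx.core as mx
import mlx.nn as nn

x = mx.random.uniform(shape=(2, 32))

# Default: learn both a multiplicative weight and an additive bias.
ln = nn.LayerNorm(32)

# bias=False keeps only the weight, so the learned transform is pure scaling.
ln_scale_only = nn.LayerNorm(32, bias=False)

# affine=False learns no parameters; in every case __call__ now dispatches to mx.fast.layer_norm.
ln_plain = nn.LayerNorm(32, affine=False)

y = ln(x)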

@@ -46,6 +46,42 @@ void init_fast(nb::module_& parent_module) {
             array: The output array.
       )pbdoc");
 
+  m.def(
+      "layer_norm",
+      [](const array& x,
+         const std::optional<array>& weight,
+         const std::optional<array>& bias,
+         float eps,
+         const StreamOrDevice& s /* = {} */) {
+        return fast::layer_norm(x, weight, bias, eps, s);
+      },
+      "x"_a,
+      "weight"_a.none(),
+      "bias"_a.none(),
+      "eps"_a,
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def layer_norm(x: array, weight: Optional[array], bias: Optional[array], eps: float, *, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Layer normalization.
+
+        The normalization is with respect to the last axis of the input ``x``.
+
+        Args:
+            x (array): Input array.
+            weight (array, optional): A multiplicative weight to scale the result by.
+              The ``weight`` should be one-dimensional with the same size
+              as the last axis of ``x``. If set to ``None`` then no scaling happens.
+            bias (array, optional): An additive offset to be added to the result.
+              The ``bias`` should be one-dimensional with the same size
+              as the last axis of ``x``. If set to ``None`` then no translation happens.
+            eps (float): A small additive constant for numerical stability.
+
+        Returns:
+            array: The output array.
+      )pbdoc");
   m.def(
       "rope",
       [](const array& a,

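Given the signature registered above, ``weight`` and ``bias`` are independent optional arguments. A short sketch of calling the op directly from Python (array sizes chosen arbitrarily):

import mlx.core as mx

x = mx.random.uniform(shape=(4, 64))
weight = mx.ones((64,))
bias = mx.zeros((64,))

# Normalize over the last axis, then scale and shift.
y = mx.fast.layer_norm(x, weight, bias, 1e-5)

# Either parameter may be None: scale only, or plain normalization.
y_scale_only = mx.fast.layer_norm(x, weight, None, 1e-5)
y_plain = mx.fast.layer_norm(x, None, None, 1e-5)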

@@ -166,6 +166,105 @@ class TestFast(mlx_tests.MLXTestCase):
         rx_fast = mx.fast.rms_norm(x, weight, eps)
         self.assertLess(mx.abs(rx - rx_fast).max(), 1e-6)
 
+    def test_layer_norm(self):
+        def layer_norm(x, weight, bias, eps):
+            ot = x.dtype
+            x = x.astype(mx.float32)
+            mean = x.mean(axis=-1, keepdims=True)
+            var = x.var(axis=-1, keepdims=True)
+            x = (x - mean) * mx.rsqrt(var + eps)
+            x = x.astype(ot)
+            if weight is not None:
+                x = x * weight
+            if bias is not None:
+                x = x + bias
+            return x
+
+        # Per dtype absolute tolerance
+        tolerances = {mx.float32: 2e-6, mx.float16: 2e-3, mx.bfloat16: 2e-2}
+
+        dtypes = [mx.float32, mx.float16, mx.bfloat16]
+        epss = [1e-3, 1e-5]
+        dimss = [31, 32, 33]
+        defaults = (mx.float32, 1e-5, 32)
+
+        for dtype in dtypes:
+            _, eps, dims = defaults
+            x = mx.random.uniform(
+                shape=(
+                    2,
+                    dims,
+                )
+            ).astype(dtype)
+            weight = mx.random.uniform(shape=(dims,)).astype(dtype)
+            bias = mx.random.uniform(shape=(dims,)).astype(dtype)
+            rx = layer_norm(x, weight, bias, eps)
+            rx_fast = mx.fast.layer_norm(x, weight, bias, eps)
+            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+            rx = layer_norm(x, weight, None, eps)
+            rx_fast = mx.fast.layer_norm(x, weight, None, eps)
+            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+            rx = layer_norm(x, None, bias, eps)
+            rx_fast = mx.fast.layer_norm(x, None, bias, eps)
+            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+            rx = layer_norm(x, None, None, eps)
+            rx_fast = mx.fast.layer_norm(x, None, None, eps)
+            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+
+        for eps in epss:
+            dtype, _, dims = defaults
+            x = mx.random.uniform(shape=(2, dims)).astype(dtype)
+            weight = mx.random.uniform(shape=(dims,)).astype(dtype)
+            bias = mx.random.uniform(shape=(dims,)).astype(dtype)
+            rx = layer_norm(x, weight, bias, eps)
+            rx_fast = mx.fast.layer_norm(x, weight, bias, eps)
+            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+            rx = layer_norm(x, weight, None, eps)
+            rx_fast = mx.fast.layer_norm(x, weight, None, eps)
+            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+            rx = layer_norm(x, None, bias, eps)
+            rx_fast = mx.fast.layer_norm(x, None, bias, eps)
+            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+            rx = layer_norm(x, None, None, eps)
+            rx_fast = mx.fast.layer_norm(x, None, None, eps)
+            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+
+        for dims in dimss:
+            dtype, eps, _ = defaults
+            x = mx.random.uniform(shape=(2, dims)).astype(dtype)
+            weight = mx.random.uniform(shape=(dims,)).astype(dtype)
+            bias = mx.random.uniform(shape=(dims,)).astype(dtype)
+            rx = layer_norm(x, weight, bias, eps)
+            rx_fast = mx.fast.layer_norm(x, weight, bias, eps)
+            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+            rx = layer_norm(x, weight, None, eps)
+            rx_fast = mx.fast.layer_norm(x, weight, None, eps)
+            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+            rx = layer_norm(x, None, bias, eps)
+            rx_fast = mx.fast.layer_norm(x, None, bias, eps)
+            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+            rx = layer_norm(x, None, None, eps)
+            rx_fast = mx.fast.layer_norm(x, None, None, eps)
+            self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+
+        # Test > 4096
+        dims, dtype, eps = 4099, mx.float32, 1e-5
+        x = mx.random.uniform(shape=(dims,)).astype(dtype)
+        weight = mx.random.uniform(shape=(dims,)).astype(dtype)
+        bias = mx.random.uniform(shape=(dims,)).astype(dtype)
+        rx = layer_norm(x, weight, bias, eps)
+        rx_fast = mx.fast.layer_norm(x, weight, bias, eps)
+        self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+        rx = layer_norm(x, weight, None, eps)
+        rx_fast = mx.fast.layer_norm(x, weight, None, eps)
+        self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+        rx = layer_norm(x, None, bias, eps)
+        rx_fast = mx.fast.layer_norm(x, None, bias, eps)
+        self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+        rx = layer_norm(x, None, None, eps)
+        rx_fast = mx.fast.layer_norm(x, None, None, eps)
+        self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])
+
     def test_fast_transforms(self):
         x = mx.random.uniform(shape=(2, 2, 8))