Adds mx.fast.layer_norm (#870)

Angelos Katharopoulos
2024-03-21 13:55:51 -07:00
committed by GitHub
parent 105d236889
commit 2225374060
11 changed files with 600 additions and 8 deletions


@@ -85,13 +85,19 @@ class LayerNorm(Module):
         eps (float): A small additive constant for numerical stability
         affine (bool): If True learn an affine transform to apply after the
             normalization
+        bias (bool): If True include a translation to the affine
+            transformation. If set to False the transformation is not really affine
+            just scaling.
     """

-    def __init__(self, dims: int, eps: float = 1e-5, affine: bool = True):
+    def __init__(
+        self, dims: int, eps: float = 1e-5, affine: bool = True, bias: bool = True
+    ):
         super().__init__()
         if affine:
-            self.bias = mx.zeros((dims,))
             self.weight = mx.ones((dims,))
+            if bias:
+                self.bias = mx.zeros((dims,))
         self.eps = eps
         self.dims = dims
@@ -99,10 +105,9 @@ class LayerNorm(Module):
         return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"

     def __call__(self, x):
-        means = mx.mean(x, axis=-1, keepdims=True)
-        var = mx.var(x, axis=-1, keepdims=True)
-        x = (x - means) * mx.rsqrt(var + self.eps)
-        return (self.weight * x + self.bias) if "weight" in self else x
+        weight = self.weight if "weight" in self else None
+        bias = self.bias if "bias" in self else None
+        return mx.fast.layer_norm(x, weight, bias, self.eps)

 class RMSNorm(Module):
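
For reference, a minimal sketch of the change in isolation (assumptions: mlx is installed and imported as mx, and the shapes and tolerance below are arbitrary illustrative values). It compares the manual computation removed from __call__ above with the fused mx.fast.layer_norm call that replaces it; per the diff, weight and bias may be passed as None to skip the affine scale and the translation, respectively.

import mlx.core as mx

x = mx.random.normal((4, 256))   # illustrative batch of 4 vectors, feature dim 256
weight = mx.ones((256,))         # same initialization as LayerNorm with affine=True
bias = mx.zeros((256,))
eps = 1e-5

# Reference computation, mirroring the lines removed from __call__ above.
means = mx.mean(x, axis=-1, keepdims=True)
var = mx.var(x, axis=-1, keepdims=True)
y_ref = weight * ((x - means) * mx.rsqrt(var + eps)) + bias

# Fused kernel added by this commit; pass None for weight and/or bias
# to skip the scale and/or the translation.
y_fast = mx.fast.layer_norm(x, weight, bias, eps)

print(mx.allclose(y_ref, y_fast, atol=1e-4))  # expected: True (up to numerical tolerance)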