Mirror of https://github.com/ml-explore/mlx.git, synced 2025-09-18 10:26:56 +08:00.

Commit: Adds mx.fast.layer_norm (#870)
Committed by: GitHub
Parent commit: 105d236889
This commit: 2225374060
@@ -85,13 +85,19 @@ class LayerNorm(Module):
        eps (float): A small additive constant for numerical stability.
        affine (bool): If True, learn an affine transform to apply after the
            normalization.
        bias (bool): If True, include a translation in the affine
            transformation. If set to False, the transformation is not really
            affine, just scaling.
    """
||||
def __init__(
    self, dims: int, eps: float = 1e-5, affine: bool = True, bias: bool = True
):
    """Initialize the layer-normalization parameters.

    Args:
        dims (int): Size of the feature dimension to normalize over.
        eps (float): A small additive constant for numerical stability.
        affine (bool): If True, learn a scale (and, when ``bias`` is True,
            a translation) applied after normalization.
        bias (bool): If True, include a translation in the affine
            transformation. If False, the transformation only scales, so it
            is not really affine.
    """
    super().__init__()
    if affine:
        self.weight = mx.ones((dims,))
        # Allocate the bias only when requested: the presence of the
        # "bias" attribute on the module is what the forward pass checks
        # to decide whether to apply a translation.
        if bias:
            self.bias = mx.zeros((dims,))
    self.eps = eps
    self.dims = dims
|
||||
@@ -99,10 +105,9 @@ class LayerNorm(Module):
|
||||
return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"
|
||||
|
||||
def __call__(self, x):
    """Apply layer normalization to ``x`` over its last axis.

    Delegates to the fused ``mx.fast.layer_norm`` kernel. The scale and
    translation are passed as ``None`` when the corresponding parameters
    were not created in ``__init__`` (``affine=False`` or ``bias=False``),
    which skips that part of the transform.
    """
    # Membership tests ("weight" in self) detect whether the optional
    # parameters exist on this Module instance.
    weight = self.weight if "weight" in self else None
    bias = self.bias if "bias" in self else None
    return mx.fast.layer_norm(x, weight, bias, self.eps)
|
||||
|
||||
|
||||
class RMSNorm(Module):
|
||||
|
Reference in New Issue
Block a user