mlx.nn.RMSNorm

class mlx.nn.RMSNorm(dims: int, eps: float = 1e-05)

Applies Root Mean Square normalization [1] to the inputs.

Computes

\[y = \frac{x}{\sqrt{E[x^2] + \epsilon}} \gamma\]

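where \(E[x^2]\) is the mean of the squared inputs over the feature dimension, \(\gamma\) is a learned per-feature scale parameter initialized to 1, and \(\epsilon\) is the constant eps.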
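As a rough illustration of the formula, the following plain-Python sketch normalizes a single feature vector. It is not the MLX implementation: the helper name rms_norm and the use of Python lists are assumptions made only for this example, and gamma is passed in explicitly rather than learned.

```python
import math

def rms_norm(x, gamma, eps=1e-5):
    """Reference RMSNorm over one feature vector (hypothetical helper, not MLX)."""
    if len(x) != len(gamma):
        raise ValueError("gamma must have one entry per feature")
    # E[x^2]: mean of the squared inputs over the feature dimension.
    mean_sq = sum(v * v for v in x) / len(x)
    # 1 / sqrt(E[x^2] + eps), shared by every feature of this vector.
    scale = 1.0 / math.sqrt(mean_sq + eps)
    # Multiply each normalized feature by its own gamma entry.
    return [v * scale * g for v, g in zip(x, gamma)]

x = [3.0, 4.0]
gamma = [1.0, 1.0]  # matches the stated initialization of gamma at 1
y = rms_norm(x, gamma)
```

With gamma left at its initial value of 1, the mean square of y is close to 1, which is the sense in which the output is normalized.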

[1]: https://arxiv.org/abs/1910.07467

Parameters:
  • dims (int) – The size of the feature dimension of the input, over which the normalization is computed.

  • eps (float) – A small additive constant for numerical stability. Defaults to 1e-05.
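The role of eps can be seen on an all-zero input, where \(E[x^2] = 0\). The sketch below is plain Python written only for illustration, with a hypothetical helper rms_scale; it does not call MLX. In array libraries the eps = 0 case would typically surface as inf or NaN values rather than an exception.

```python
import math

def rms_scale(x, eps):
    """Return 1 / sqrt(E[x^2] + eps) for one feature vector (hypothetical helper)."""
    mean_sq = sum(v * v for v in x) / len(x)
    return 1.0 / math.sqrt(mean_sq + eps)

zeros = [0.0, 0.0, 0.0, 0.0]

# With eps = 1e-5 the scale is finite, so the normalized output is all zeros.
scale = rms_scale(zeros, eps=1e-5)
normalized = [v * scale for v in zeros]

# With eps = 0 the same input requires dividing by sqrt(0).
try:
    rms_scale(zeros, eps=0.0)
    failed = False
except ZeroDivisionError:
    failed = True
```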