mlx.nn.BatchNorm#
- class BatchNorm(num_features: int, eps: float = 1e-05, momentum: float = 0.1, affine: bool = True, track_running_stats: bool = True)#
Applies Batch Normalization over a 2D, 3D, or 4D input.
Computes
\[y = \frac{x - E[x]}{\sqrt{Var[x] + \epsilon}} \gamma + \beta,\]where \(\gamma\) and \(\beta\) are learned per feature dimension parameters initialized at 1 and 0 respectively.
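As a concrete check of this formula, the following minimal sketch normalizes a small batch by hand and compares it against a freshly constructed BatchNorm, whose \(\gamma\) and \(\beta\) are still at their initial values of 1 and 0 (the array sizes are arbitrary):

```python
import mlx.core as mx
import mlx.nn as nn

x = mx.random.normal((8, 3))

# Per-feature statistics over the batch axis: E[x] and Var[x].
mean = mx.mean(x, axis=0)
var = mx.var(x, axis=0)
eps = 1e-5

# gamma = 1 and beta = 0 at initialization, so they drop out here.
y_manual = (x - mean) / mx.sqrt(var + eps)

bn = nn.BatchNorm(num_features=3)
y_module = bn(x)
print(mx.allclose(y_manual, y_module, atol=1e-4))  # expected: True
```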
The input shape is specified as NC or NLC, where N is the batch, C is the number of features or channels, and L is the sequence length. The output has the same shape as the input. For four-dimensional arrays, the shape is NHWC, where H and W are the height and width respectively.

For more information on Batch Normalization, see the original paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.
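For illustration, here is a sketch of the three accepted layouts (the axis sizes are arbitrary); in every case the output shape matches the input:

```python
import mlx.core as mx
import mlx.nn as nn

bn = nn.BatchNorm(num_features=4)

x_nc = mx.random.normal((8, 4))          # NC: batch of 8, 4 features
x_nlc = mx.random.normal((8, 16, 4))     # NLC: sequence length 16
x_nhwc = mx.random.normal((8, 5, 5, 4))  # NHWC: 5x5 spatial grid

for x in (x_nc, x_nlc, x_nhwc):
    print(bn(x).shape)  # same shape as the corresponding input
```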
- Parameters:
  - num_features (int) – The feature dimension to normalize over.
  - eps (float, optional) – A small additive constant for numerical stability. Default: 1e-5.
  - momentum (float, optional) – The momentum for updating the running mean and variance (see the sketch after this list). Default: 0.1.
  - affine (bool, optional) – If True, apply a learned affine transformation after the normalization. Default: True.
  - track_running_stats (bool, optional) – If True, track the running mean and variance. Default: True.
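The sketch below illustrates the running-mean update, assuming the usual exponential-moving-average convention, running = (1 - momentum) * running + momentum * batch_statistic, and assuming the tracked statistics are exposed as the running_mean and running_var attributes:

```python
import mlx.core as mx
import mlx.nn as nn

m = 0.1
bn = nn.BatchNorm(num_features=4, momentum=m)  # track_running_stats=True

x = mx.random.normal((32, 4))
before = bn.running_mean  # starts at zeros
_ = bn(x)                 # a training-mode forward pass updates the stats

# Expected update under the EMA convention assumed above.
expected = (1 - m) * before + m * mx.mean(x, axis=0)
print(mx.allclose(bn.running_mean, expected, atol=1e-5))  # expected: True
```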
Examples
```python
>>> import mlx.core as mx
>>> import mlx.nn as nn
>>> x = mx.random.normal((5, 4))
>>> bn = nn.BatchNorm(num_features=4, affine=True)
>>> output = bn(x)
```
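Once statistics have been tracked, calling eval() on the module (which toggles Module.training) makes it normalize with the running statistics rather than the per-batch ones; a brief usage sketch:

```python
>>> bn.eval()       # switch to inference mode
>>> output = bn(x)  # normalizes with the tracked running statistics
>>> bn.train()      # back to training mode; batch statistics again
```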