diff --git a/python/mlx/nn/layers/normalization.py b/python/mlx/nn/layers/normalization.py
index 6de377cda..5807d3d7a 100644
--- a/python/mlx/nn/layers/normalization.py
+++ b/python/mlx/nn/layers/normalization.py
@@ -1,5 +1,7 @@
 # Copyright © 2023 Apple Inc.
 
+from typing import Tuple
+
 import mlx.core as mx
 from mlx.nn.layers.base import Module
 
@@ -97,7 +99,7 @@ class GroupNorm(Module):
     where :math:`\gamma` and :math:`\beta` are learned per feature dimension
     parameters initialized at 1 and 0 respectively. However, the mean and
     variance are computed over the spatial dimensions and each group of
-    features. In particular, the input is split into num_groups accross the
+    features. In particular, the input is split into num_groups across the
     feature dimension.
 
     The feature dimension is assumed to be the last dimension and the dimensions
@@ -178,3 +180,95 @@ class GroupNorm(Module):
         )
         x = group_norm(x)
         return (self.weight * x + self.bias) if "weight" in self else x
+
+
+class BatchNorm1d(Module):
+    r"""Applies Batch Normalization [1] to the inputs.
+
+    Computes
+
+    .. math::
+
+        y = \frac{x - E[x]}{\sqrt{Var[x] + \epsilon}} \gamma + \beta,
+
+    where :math:`\gamma` and :math:`\beta` are learned per feature dimension
+    parameters initialized at 1 and 0 respectively.
+
+    [1]: https://arxiv.org/abs/1502.03167
+
+    Args:
+        num_features (int): The feature dimension of the input to normalize over.
+        eps (float, optional): A small additive constant for numerical stability. Default is 1e-5.
+        momentum (float, optional): The momentum for updating the running mean and variance. Default is 0.1.
+        affine (bool, optional): If True, learn an affine transform to apply after the normalization. Default is True.
+
+    Examples:
+        >>> import mlx.core as mx
+        >>> import mlx.nn as nn
+
+        >>> # With Learnable Parameters
+        >>> m = nn.BatchNorm1d(100)
+        >>> # Without Learnable Parameters
+        >>> m = nn.BatchNorm1d(4, affine=False)
+        >>> input = mx.random.normal((20, 4))
+        >>> output = m(input)
+
+    """
+
+    def __init__(
+        self,
+        num_features: int,
+        eps: float = 1e-5,
+        momentum: float = 0.1,
+        affine: bool = True,
+    ):
+        super().__init__()
+        if affine:
+            self.bias = mx.zeros((num_features,))
+            self.weight = mx.ones((num_features,))
+
+        self.num_features = num_features
+        self.eps = eps
+        self.momentum = momentum
+        self.running_mean = mx.zeros((num_features,))
+        self.running_var = mx.ones((num_features,))
+
+    def _extra_repr(self):
+        return f"num_features={self.num_features}, eps={self.eps}, momentum={self.momentum}, affine={'weight' in self}"
+
+    def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]:
+        """
+        Calculate the mean and variance of the input and update the running statistics.
+
+        Args:
+            x (mx.array): Input tensor.
+
+        Returns:
+            tuple: Tuple containing mean and variance.
+        """
+        means = mx.mean(x, axis=0, keepdims=True)
+        var = mx.var(x, axis=0, keepdims=True)
+        self.running_mean = (
+            (1 - self.momentum) * self.running_mean + self.momentum * means
+        )
+        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
+        return means, var
+
+    def __call__(self, x: mx.array):
+        """
+        Forward pass of BatchNorm1d.
+
+        Args:
+            x (mx.array): Input tensor.
+
+        Returns:
+            mx.array: Output tensor.
+        """
+        if x.ndim != 2:
+            raise ValueError("BatchNorm1d only supports 2D inputs")
+
+        means, var = self.running_mean, self.running_var
+        if self.training:
+            means, var = self._calc_stats(x)
+        x = (x - means) * mx.rsqrt(var + self.eps)
+        return (self.weight * x + self.bias) if "weight" in self else x
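
A quick usage sketch (not part of the patch) showing the intended train/eval behavior of the new layer; it assumes the Module.train() / Module.eval() toggles that drive the self.training flag read in __call__:

import mlx.core as mx
import mlx.nn as nn

bn = nn.BatchNorm1d(4)

# Training mode: normalize with the batch statistics and update the
# running estimates via _calc_stats.
bn.train()
x = mx.random.normal((20, 4))
y = bn(x)
print(mx.mean(y, axis=0))  # roughly zero per feature

# Eval mode: reuse the stored running statistics instead of batch statistics.
bn.eval()
print(bn(mx.random.normal((5, 4))).shape)  # (5, 4)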