implemented batchnorm layer

m0saan 2023-12-18 23:26:37 +01:00
parent 22fee5a383
commit 2b617b63bd


@@ -1,5 +1,7 @@
# Copyright © 2023 Apple Inc.
from typing import Tuple
import mlx.core as mx
from mlx.nn.layers.base import Module
@@ -97,7 +99,7 @@ class GroupNorm(Module):
    where :math:`\gamma` and :math:`\beta` are learned per feature dimension
    parameters initialized at 1 and 0 respectively. However, the mean and
    variance are computed over the spatial dimensions and each group of
    features. In particular, the input is split into num_groups across the
    feature dimension.

    The feature dimension is assumed to be the last dimension and the dimensions
@@ -178,3 +180,95 @@ class GroupNorm(Module):
        )
        x = group_norm(x)
        return (self.weight * x + self.bias) if "weight" in self else x
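
For context, a minimal usage sketch of the GroupNorm layer whose docstring is touched above; the shapes here are illustrative assumptions, not taken from the commit:

import mlx.core as mx
import mlx.nn as nn

# 16 features split into 4 groups; statistics are computed per group
# over the spatial dimension, with features in the last axis.
group_norm = nn.GroupNorm(num_groups=4, dims=16)
x = mx.random.normal((8, 32, 16))  # (batch, spatial, features)
y = group_norm(x)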

class BatchNorm1d(Module):
    r"""Applies Batch Normalization [1] to the inputs.

    Computes

    .. math::

        y = \frac{x - E[x]}{\sqrt{Var[x] + \epsilon}} \gamma + \beta,

    where :math:`\gamma` and :math:`\beta` are learned per feature dimension
    parameters initialized at 1 and 0 respectively.

    [1]: https://arxiv.org/abs/1502.03167

    Args:
        num_features (int): The feature dimension of the input to normalize over.
        eps (float, optional): A small additive constant for numerical stability. Default is 1e-5.
        momentum (float, optional): The momentum for updating the running mean and variance. Default is 0.1.
        affine (bool, optional): If True, learn an affine transform to apply after the normalization. Default is True.

    Examples:
        >>> import mlx.core as mx
        >>> import mlx.nn as nn
        >>> # With learnable parameters
        >>> m = nn.BatchNorm1d(100)
        >>> # Without learnable parameters
        >>> m = nn.BatchNorm1d(4, affine=False)
        >>> input = mx.random.normal((20, 4))
        >>> output = m(input)
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = True,
    ):
        super().__init__()
        if affine:
            self.bias = mx.zeros((num_features,))
            self.weight = mx.ones((num_features,))
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.running_mean = mx.zeros((num_features,))
        self.running_var = mx.ones((num_features,))

    def _extra_repr(self):
        return f"num_features={self.num_features}, eps={self.eps}, momentum={self.momentum}, affine={'weight' in self}"

    def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]:
        """
        Calculate the mean and variance of the input tensor and update the
        running statistics.

        Args:
            x (mx.array): Input tensor.

        Returns:
            tuple: Tuple containing the mean and variance.
        """
        means = mx.mean(x, axis=0, keepdims=True)
        var = mx.var(x, axis=0, keepdims=True)

        # Exponential moving average of the batch statistics: the new batch
        # gets weight `momentum` (0.1 by default), matching the documented
        # semantics of a slowly updating running estimate.
        self.running_mean = (
            1 - self.momentum
        ) * self.running_mean + self.momentum * means
        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
        return means, var

    def __call__(self, x: mx.array):
        """
        Forward pass of BatchNorm1d.

        Args:
            x (mx.array): Input tensor.

        Returns:
            mx.array: Output tensor.
        """
        if x.ndim != 2:
            raise ValueError("BatchNorm1d only supports 2D inputs")

        # Use batch statistics during training, running statistics otherwise.
        means, var = self.running_mean, self.running_var
        if self.training:
            means, var = self._calc_stats(x)
        x = (x - means) * mx.rsqrt(var + self.eps)
        return (self.weight * x + self.bias) if "weight" in self else x
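
A rough usage sketch of the new layer, assuming the standard mlx.nn Module training flag that `self.training` above relies on (toggled via train()/eval()):

import mlx.core as mx
import mlx.nn as nn

bn = nn.BatchNorm1d(4)
x = mx.random.normal((20, 4))

# Training mode: normalize with batch statistics and update the
# running mean/variance via the momentum rule in _calc_stats.
bn.train()
y_train = bn(x)

# Evaluation mode: reuse the accumulated running statistics.
bn.eval()
y_eval = bn(x)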