implement-batch-norm-layer (#217)

- Add batch normalization layer

---------

Co-authored-by: Robert McCraith <mccraithrobert@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
__mo_san__ 2023-12-25 16:32:53 +01:00 committed by GitHub
parent 9e6b8c9f48
commit a123c3c7d2
6 changed files with 267 additions and 11 deletions

View File

@@ -20,6 +20,7 @@ Layers
Linear
Conv1d
Conv2d
BatchNorm
LayerNorm
RMSNorm
GroupNorm

View File

@@ -36,7 +36,7 @@ from mlx.nn.layers.convolution import Conv1d, Conv2d
from mlx.nn.layers.dropout import Dropout, Dropout2d
from mlx.nn.layers.embedding import Embedding
from mlx.nn.layers.linear import Linear
from mlx.nn.layers.normalization import GroupNorm, LayerNorm, RMSNorm
from mlx.nn.layers.normalization import BatchNorm, GroupNorm, LayerNorm, RMSNorm
from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
from mlx.nn.layers.quantized import QuantizedLinear
from mlx.nn.layers.transformer import (

View File

@@ -5,7 +5,7 @@ from mlx.nn.layers.base import Module
class Dropout(Module):
"""Randomly zero a portion of the elements during training.
r"""Randomly zero a portion of the elements during training.
The remaining elements are multiplied with :math:`\frac{1}{1-p}` where
:math:`p` is the probability of zeroing an element. This is done so the
@@ -36,15 +36,13 @@ class Dropout(Module):
class Dropout2d(Module):
"""Apply 2D channel-wise dropout during training.
r"""Apply 2D channel-wise dropout during training.
Randomly zero out entire channels independently with probability :math:`p`.
This layer expects the channels to be last, i.e. the input shape should be
``NWHC`` or ``WHC`` where:
- ``N`` is the batch dimension
- ``H`` is the input image height
- ``W`` is the input image width
- ``C`` is the number of input channels
``NWHC`` or ``WHC`` where: ``N`` is the batch dimension, ``H`` is the
input image height, ``W`` is the input image width, and ``C`` is the
number of input channels.
The remaining channels are scaled by :math:`\frac{1}{1-p}` to
maintain the expected value of each element. Unlike traditional dropout,
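A brief usage sketch of the channel-wise dropout described above (shapes and names are illustrative; it assumes the existing ``nn.Dropout2d`` layer with a zeroing probability of 0.5):

import mlx.core as mx
import mlx.nn as nn

x = mx.random.normal((2, 8, 8, 3))  # channels-last input: batch of 2, 8x8 spatial, 3 channels
drop = nn.Dropout2d(0.5)
drop.train()                        # dropout is only active in training mode
y = drop(x)                         # whole channels are zeroed; the rest are scaled by 1/(1-p)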

View File

@@ -1,5 +1,7 @@
# Copyright © 2023 Apple Inc.
from typing import Tuple
import mlx.core as mx
from mlx.nn.layers.base import Module
@@ -178,3 +180,121 @@ class GroupNorm(Module):
)
x = group_norm(x)
return (self.weight * x + self.bias) if "weight" in self else x
class BatchNorm(Module):
r"""Applies Batch Normalization over a 2D or 3D input.
Computes
.. math::
y = \frac{x - E[x]}{\sqrt{Var[x] + \epsilon}} \gamma + \beta,
where :math:`\gamma` and :math:`\beta` are learned per feature dimension
parameters initialized at 1 and 0 respectively.
The input shape is specified as ``NC`` or ``NLC``, where ``N`` is the
batch, ``C`` is the number of features or channels, and ``L`` is the
sequence length. The output has the same shape as the input. For
four-dimensional arrays, the shape is ``NHWC``, where ``H`` and ``W`` are
the height and width respectively.
For more information on Batch Normalization, see the original paper `Batch
Normalization: Accelerating Deep Network Training by Reducing Internal
Covariate Shift <https://arxiv.org/abs/1502.03167>`_.
Args:
num_features (int): The feature dimension to normalize over.
eps (float, optional): A small additive constant for numerical
stability. Default: ``1e-5``.
momentum (float, optional): The momentum for updating the running
mean and variance. Default: ``0.1``.
affine (bool, optional): If ``True``, apply a learned affine
transformation after the normalization. Default: ``True``.
track_running_stats (bool, optional): If ``True``, track the
running mean and variance. Default: ``True``.
Examples:
>>> import mlx.core as mx
>>> import mlx.nn as nn
>>> x = mx.random.normal((5, 4))
>>> bn = nn.BatchNorm(num_features=4, affine=True)
>>> output = bn(x)
"""
def __init__(
self,
num_features: int,
eps: float = 1e-5,
momentum: float = 0.1,
affine: bool = True,
track_running_stats: bool = True,
):
super().__init__()
self.num_features = num_features
self.eps = eps
self.momentum = momentum
self.track_running_stats = track_running_stats
if affine:
self.weight = mx.ones((num_features,))
self.bias = mx.zeros((num_features,))
if self.track_running_stats:
self._running_mean = mx.zeros((num_features,))
self._running_var = mx.ones((num_features,))
def _extra_repr(self):
return (
f"{self.num_features}, eps={self.eps}, "
f"momentum={self.momentum}, affine={'weight' in self}, "
f"track_running_stats={self.track_running_stats}"
)
def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]:
"""
Calculate the per-feature mean and variance of the input array. When
training with ``track_running_stats`` enabled, this also updates the
running statistics using the configured momentum.
Args:
x (mx.array): Input array.
Returns:
tuple: Tuple containing the mean and the variance.
"""
reduction_axes = tuple(range(0, x.ndim - 1))
means = mx.mean(x, axis=reduction_axes, keepdims=True)
var = mx.var(x, axis=reduction_axes, keepdims=True)
if self.track_running_stats and self.training:
self._running_mean = (
1 - self.momentum
) * self._running_mean + self.momentum * means
self._running_var = (
1 - self.momentum
) * self._running_var + self.momentum * var
return means, var
def __call__(self, x: mx.array) -> mx.array:
"""
Forward pass of BatchNorm.
Args:
x (mx.array): Input tensor.
Returns:
mx.array: Output tensor.
"""
if x.ndim < 2 or x.ndim > 4:
raise ValueError(
f"Expected input tensor to have 2, 3 or 4 dimensions, but got {x.ndim}"
)
if self.training or not self.track_running_stats:
means, var = self._calc_stats(x)
else:
means, var = self._running_mean, self._running_var
x = (x - means) * mx.rsqrt(var + self.eps)
return (self.weight * x + self.bias) if "weight" in self else x
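A minimal usage sketch of the layer above (names are illustrative; the train/eval switching mirrors what the tests below exercise):

import mlx.core as mx
import mlx.nn as nn

x = mx.random.normal((8, 4))       # NC input: batch of 8 samples, 4 features
bn = nn.BatchNorm(num_features=4)  # affine parameters and running stats enabled by default

bn.train()       # training mode: normalize with the batch statistics
y_train = bn(x)  # also updates _running_mean / _running_var via the momentum EMA

bn.eval()        # eval mode: normalize with the tracked running statistics instead
y_eval = bn(x)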

View File

@@ -286,7 +286,7 @@ def _reduce(loss: mx.array, reduction: str = "none"):
def hinge_loss(
inputs: mx.array, targets: mx.array, reduction: str = "none"
) -> mx.array:
"""
r"""
Computes the hinge loss between inputs and targets.
.. math::
@@ -311,7 +311,7 @@ def hinge_loss(
def huber_loss(
inputs: mx.array, targets: mx.array, delta: float = 1.0, reduction: str = "none"
) -> mx.array:
"""
r"""
Computes the Huber loss between inputs and targets.
.. math::
@@ -345,7 +345,7 @@ def huber_loss(
def log_cosh_loss(
inputs: mx.array, targets: mx.array, reduction: str = "none"
) -> mx.array:
"""
r"""
Computes the log cosh loss between inputs and targets.
Logcosh acts like L2 loss for small errors, ensuring stable gradients,
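As a quick illustration of the small-error behavior mentioned above (a sketch only; the values are made up, and the call follows the ``log_cosh_loss`` signature shown in this diff):

import mlx.core as mx
import mlx.nn as nn

inputs = mx.array([0.1, 0.2, -0.05])  # hypothetical predictions close to the targets
targets = mx.zeros((3,))

lc = nn.losses.log_cosh_loss(inputs, targets, reduction="mean")
# For small errors e, log(cosh(e)) is approximately e**2 / 2, so the loss
# tracks an L2-style quadratic penalty in this regime.
l2 = mx.mean(0.5 * (inputs - targets) ** 2)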

View File

@@ -320,6 +320,143 @@ class TestNN(mlx_tests.MLXTestCase):
self.assertTrue(np.allclose(means, 3 * np.ones_like(means), atol=1e-6))
self.assertTrue(np.allclose(var, 4 * np.ones_like(var), atol=1e-6))
def test_batch_norm(self):
mx.random.seed(42)
x = mx.random.normal((5, 4), dtype=mx.float32)
# Batch norm
bn = nn.BatchNorm(num_features=4, affine=True)
self.assertTrue(mx.allclose(bn._running_mean, mx.zeros_like(bn._running_mean)))
self.assertTrue(mx.allclose(bn._running_var, mx.ones_like(bn._running_var)))
y = bn(x)
expected_y = mx.array(
[
[-0.439520, 1.647328, -0.955515, 1.966031],
[-1.726690, -1.449826, -0.234026, -0.723364],
[0.938414, -0.349603, -0.354470, -0.175369],
[0.305006, 0.234914, -0.393017, -0.459385],
[0.922789, -0.082813, 1.937028, -0.607913],
],
)
expected_mean = mx.array([0.008929, 0.005680, -0.016092, 0.027778])
expected_var = mx.array([0.928435, 1.00455, 1.04117, 0.94258])
self.assertTrue(x.shape == y.shape)
self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))
self.assertTrue(mx.allclose(bn._running_mean, expected_mean, atol=1e-5))
self.assertTrue(mx.allclose(bn._running_var, expected_var, atol=1e-5))
# test eval mode
bn.eval()
y = bn(x)
expected_y = mx.array(
[
[-0.15984, 1.73159, -1.25456, 1.57891],
[-0.872193, -1.4281, -0.414439, -0.228678],
[0.602743, -0.30566, -0.554687, 0.139639],
[0.252199, 0.29066, -0.599572, -0.0512532],
[0.594096, -0.0334829, 2.11359, -0.151081],
]
)
self.assertTrue(x.shape == y.shape)
self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))
# test_no_affine
bn = nn.BatchNorm(num_features=4, affine=False)
y = bn(x)
expected_y = mx.array(
[
[-0.439520, 1.647328, -0.955515, 1.966031],
[-1.726690, -1.449826, -0.234026, -0.723364],
[0.938414, -0.349603, -0.354470, -0.175369],
[0.305006, 0.234914, -0.393017, -0.459385],
[0.922789, -0.082813, 1.937028, -0.607913],
]
)
self.assertTrue(x.shape == y.shape)
self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))
# test with 3D input
mx.random.seed(42)
N = 2
L = 4
C = 5
x = mx.random.normal((N, L, C), dtype=mx.float32)
# Batch norm
bn = nn.BatchNorm(num_features=C, affine=True)
self.assertTrue(mx.allclose(bn._running_mean, mx.zeros_like(bn._running_mean)))
self.assertTrue(mx.allclose(bn._running_var, mx.ones_like(bn._running_var)))
y = bn(x)
self.assertTrue(x.shape == y.shape)
expected_y = mx.array(
[
[
[-0.335754, 0.342054, 1.02653, 0.628588, -1.63899],
[1.92092, 0.432319, 0.343043, 1.95489, 1.0696],
[-0.853748, 1.3661, 0.868569, 0.0199196, -0.887284],
[0.459206, -0.684822, -0.706354, -0.271531, 0.566341],
],
[
[-0.921179, 0.684951, -0.77466, -0.490372, -0.247032],
[1.10839, -2.13179, 0.628924, -1.62639, -0.539708],
[-0.348943, 0.412194, -2.03818, 0.524972, 1.64568],
[-1.02889, -0.421, 0.652127, -0.740079, 0.0313996],
],
]
)
self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))
expected_mean = mx.array(
[[[0.00207845, -5.3259e-05, 0.04755, -0.0697296, 0.0236228]]]
)
expected_var = mx.array([[[0.968415, 1.05322, 0.96913, 0.932305, 0.967224]]])
self.assertTrue(mx.allclose(bn._running_mean, expected_mean, atol=1e-5))
self.assertTrue(mx.allclose(bn._running_var, expected_var, atol=1e-5))
x = mx.random.normal((N, L, C, L, C), dtype=mx.float32)
with self.assertRaises(ValueError):
y = bn(x)
def test_batch_norm_stats(self):
batch_size = 2
num_features = 4
h = 3
w = 3
momentum = 0.1
batch_norm = nn.BatchNorm(num_features)
batch_norm.train()
running_mean = np.array(batch_norm._running_mean)
running_var = np.array(batch_norm._running_var)
data = mx.random.normal((batch_size, num_features))
normalized_data = batch_norm(data)
np_data = np.array(data)
means = np.mean(np_data, axis=0)
variances = np.var(np_data, axis=0)
running_mean = (1 - momentum) * running_mean + momentum * means
running_var = (1 - momentum) * running_var + momentum * variances
self.assertTrue(np.allclose(batch_norm._running_mean, running_mean, atol=1e-5))
self.assertTrue(np.allclose(batch_norm._running_var, running_var, atol=1e-5))
batch_norm = nn.BatchNorm(num_features)
batch_norm.train()
running_mean = np.array(batch_norm._running_mean)
running_var = np.array(batch_norm._running_var)
data = mx.random.normal((batch_size, h, w, num_features))
normalized_data = batch_norm(data)
np_data = np.array(data)
means = np.mean(np_data, axis=(0, 1, 2))
variances = np.var(np_data, axis=(0, 1, 2))
running_mean = (1 - momentum) * running_mean + momentum * means
running_var = (1 - momentum) * running_var + momentum * variances
self.assertTrue(np.allclose(batch_norm._running_mean, running_mean, atol=1e-5))
self.assertTrue(np.allclose(batch_norm._running_var, running_var, atol=1e-5))
def test_conv1d(self):
N = 5
L = 12