implement-batch-norm-layer (#217)
- Add batch normalization layer

Co-authored-by: Robert McCraith <mccraithrobert@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>

parent 9e6b8c9f48
commit a123c3c7d2
@@ -20,6 +20,7 @@ Layers

   Linear
   Conv1d
   Conv2d
+  BatchNorm
   LayerNorm
   RMSNorm
   GroupNorm
@@ -36,7 +36,7 @@ from mlx.nn.layers.convolution import Conv1d, Conv2d
 from mlx.nn.layers.dropout import Dropout, Dropout2d
 from mlx.nn.layers.embedding import Embedding
 from mlx.nn.layers.linear import Linear
-from mlx.nn.layers.normalization import GroupNorm, LayerNorm, RMSNorm
+from mlx.nn.layers.normalization import BatchNorm, GroupNorm, LayerNorm, RMSNorm
 from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
 from mlx.nn.layers.quantized import QuantizedLinear
 from mlx.nn.layers.transformer import (
@@ -5,7 +5,7 @@ from mlx.nn.layers.base import Module


 class Dropout(Module):
-    """Randomly zero a portion of the elements during training.
+    r"""Randomly zero a portion of the elements during training.

     The remaining elements are multiplied with :math:`\frac{1}{1-p}` where
     :math:`p` is the probability of zeroing an element. This is done so the

@@ -36,15 +36,13 @@ class Dropout(Module):


 class Dropout2d(Module):
-    """Apply 2D channel-wise dropout during training.
+    r"""Apply 2D channel-wise dropout during training.

     Randomly zero out entire channels independently with probability :math:`p`.
     This layer expects the channels to be last, i.e. the input shape should be
-    ``NWHC`` or ``WHC`` where:
-    - ``N`` is the batch dimension
-    - ``H`` is the input image height
-    - ``W`` is the input image width
-    - ``C`` is the number of input channels
+    ``NWHC`` or ``WHC`` where: ``N`` is the batch dimension, ``H`` is the input
+    image height, ``W`` is the input image width, and ``C`` is the number of
+    input channels.

     The remaining channels are scaled by :math:`\frac{1}{1-p}` to
     maintain the expected value of each element. Unlike traditional dropout,
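For context, a minimal usage sketch of the channels-last convention the Dropout2d docstring describes. This is not part of the commit, and it assumes the constructor takes the drop probability as ``p``:

import mlx.core as mx
import mlx.nn as nn

x = mx.random.normal((4, 8, 8, 16))  # NWHC: batch, width, height, channels
drop = nn.Dropout2d(p=0.5)           # assumption: p is the drop probability

drop.train()
y = drop(x)  # entire channels are zeroed; surviving channels are scaled by 1/(1 - p)

drop.eval()
y = drop(x)  # dropout is disabled in eval mode, so this is the identity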
@@ -1,5 +1,7 @@
 # Copyright © 2023 Apple Inc.

+from typing import Tuple
+
 import mlx.core as mx
 from mlx.nn.layers.base import Module

@@ -178,3 +180,121 @@ class GroupNorm(Module):
         )
         x = group_norm(x)
         return (self.weight * x + self.bias) if "weight" in self else x
+
+
+class BatchNorm(Module):
+    r"""Applies Batch Normalization over a 2D or 3D input.
+
+    Computes
+
+    .. math::
+
+        y = \frac{x - E[x]}{\sqrt{Var[x] + \epsilon}} \gamma + \beta,
+
+    where :math:`\gamma` and :math:`\beta` are learned per feature dimension
+    parameters initialized at 1 and 0 respectively.
+
+    The input shape is specified as ``NC`` or ``NLC``, where ``N`` is the
+    batch, ``C`` is the number of features or channels, and ``L`` is the
+    sequence length. The output has the same shape as the input. For
+    four-dimensional arrays, the shape is ``NHWC``, where ``H`` and ``W`` are
+    the height and width respectively.
+
+    For more information on Batch Normalization, see the original paper `Batch
+    Normalization: Accelerating Deep Network Training by Reducing Internal
+    Covariate Shift <https://arxiv.org/abs/1502.03167>`_.
+
+    Args:
+        num_features (int): The feature dimension to normalize over.
+        eps (float, optional): A small additive constant for numerical
+            stability. Default: ``1e-5``.
+        momentum (float, optional): The momentum for updating the running
+            mean and variance. Default: ``0.1``.
+        affine (bool, optional): If ``True``, apply a learned affine
+            transformation after the normalization. Default: ``True``.
+        track_running_stats (bool, optional): If ``True``, track the
+            running mean and variance. Default: ``True``.
+
+    Examples:
+        >>> import mlx.core as mx
+        >>> import mlx.nn as nn
+        >>> x = mx.random.normal((5, 4))
+        >>> bn = nn.BatchNorm(num_features=4, affine=True)
+        >>> output = bn(x)
+    """
+
+    def __init__(
+        self,
+        num_features: int,
+        eps: float = 1e-5,
+        momentum: float = 0.1,
+        affine: bool = True,
+        track_running_stats: bool = True,
+    ):
+        super().__init__()
+
+        self.num_features = num_features
+        self.eps = eps
+        self.momentum = momentum
+        self.track_running_stats = track_running_stats
+
+        if affine:
+            self.weight = mx.ones((num_features,))
+            self.bias = mx.zeros((num_features,))
+
+        if self.track_running_stats:
+            self._running_mean = mx.zeros((num_features,))
+            self._running_var = mx.ones((num_features,))
+
+    def _extra_repr(self):
+        return (
+            f"{self.num_features}, eps={self.eps}, "
+            f"momentum={self.momentum}, affine={'weight' in self}, "
+            f"track_running_stats={self.track_running_stats}"
+        )
+
+    def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]:
+        """
+        Calculate the mean and variance of the input tensor.
+
+        Args:
+            x (mx.array): Input tensor.
+
+        Returns:
+            tuple: Tuple containing mean and variance.
+        """
+        reduction_axes = tuple(range(0, x.ndim - 1))
+        means = mx.mean(x, axis=reduction_axes, keepdims=True)
+        var = mx.var(x, axis=reduction_axes, keepdims=True)
+
+        if self.track_running_stats and self.training:
+            self._running_mean = (
+                1 - self.momentum
+            ) * self._running_mean + self.momentum * means
+            self._running_var = (
+                1 - self.momentum
+            ) * self._running_var + self.momentum * var
+        return means, var
+
+    def __call__(self, x: mx.array) -> mx.array:
+        """
+        Forward pass of BatchNorm.
+
+        Args:
+            x (mx.array): Input tensor.
+
+        Returns:
+            mx.array: Output tensor.
+        """
+        if x.ndim < 2 or x.ndim > 4:
+            raise ValueError(
+                f"Expected input tensor to have 2, 3 or 4 dimensions, but got {x.ndim}"
+            )
+
+        if self.training or not self.track_running_stats:
+            means, var = self._calc_stats(x)
+        else:
+            means, var = self._running_mean, self._running_var
+        x = (x - means) * mx.rsqrt(var + self.eps)
+        return (self.weight * x + self.bias) if "weight" in self else x
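The docstring and the ``_calc_stats``/``__call__`` methods above define the behavior this hunk adds: in training mode (or when ``track_running_stats`` is ``False``) the batch statistics are used and the running statistics are updated as an exponential moving average with ``momentum``; in eval mode the stored running statistics are used instead. A minimal sketch of that behavior using the API from this diff (the manual computation is illustrative, not part of the commit):

import mlx.core as mx
import mlx.nn as nn

x = mx.random.normal((8, 3))       # NC input: batch of 8, 3 features
bn = nn.BatchNorm(num_features=3)  # weight=1, bias=0, running mean=0, running var=1 at init

# Training mode (the default): normalize with the batch statistics.
y_train = bn(x)
mean = mx.mean(x, axis=0, keepdims=True)
var = mx.var(x, axis=0, keepdims=True)
y_manual = (x - mean) * mx.rsqrt(var + bn.eps)  # matches y_train since the affine part is identity at init

# The running statistics were updated as (1 - momentum) * old + momentum * batch_stat,
# i.e. bn._running_mean is now 0.9 * 0 + 0.1 * mean and bn._running_var is 0.9 * 1 + 0.1 * var.

# Eval mode: the stored running statistics are used instead of the batch ones.
bn.eval()
y_eval = bn(x)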
@@ -286,7 +286,7 @@ def _reduce(loss: mx.array, reduction: str = "none"):
 def hinge_loss(
     inputs: mx.array, targets: mx.array, reduction: str = "none"
 ) -> mx.array:
-    """
+    r"""
     Computes the hinge loss between inputs and targets.

     .. math::

@@ -311,7 +311,7 @@ def hinge_loss(
 def huber_loss(
     inputs: mx.array, targets: mx.array, delta: float = 1.0, reduction: str = "none"
 ) -> mx.array:
-    """
+    r"""
     Computes the Huber loss between inputs and targets.

     .. math::

@@ -345,7 +345,7 @@ def huber_loss(
 def log_cosh_loss(
     inputs: mx.array, targets: mx.array, reduction: str = "none"
 ) -> mx.array:
-    """
+    r"""
     Computes the log cosh loss between inputs and targets.

     Logcosh acts like L2 loss for small errors, ensuring stable gradients,
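The three hunks above only turn the loss docstrings into raw strings so their ``:math:`` blocks render as written. For orientation, a small usage sketch of the signatures shown in these hunks; it is not part of the commit, it assumes the losses are reachable as ``nn.losses``, and the concrete values are illustrative only:

import mlx.core as mx
import mlx.nn as nn

preds = mx.array([0.5, -1.2, 2.0])
targets = mx.array([1.0, -1.0, 1.0])  # hinge loss expects +/-1 targets

hinge = nn.losses.hinge_loss(preds, targets, reduction="mean")
huber = nn.losses.huber_loss(preds, targets, delta=1.0, reduction="mean")
logcosh = nn.losses.log_cosh_loss(preds, targets, reduction="mean")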
@@ -320,6 +320,143 @@ class TestNN(mlx_tests.MLXTestCase):
         self.assertTrue(np.allclose(means, 3 * np.ones_like(means), atol=1e-6))
         self.assertTrue(np.allclose(var, 4 * np.ones_like(var), atol=1e-6))

+    def test_batch_norm(self):
+        mx.random.seed(42)
+        x = mx.random.normal((5, 4), dtype=mx.float32)
+
+        # Batch norm
+        bn = nn.BatchNorm(num_features=4, affine=True)
+        self.assertTrue(mx.allclose(bn._running_mean, mx.zeros_like(bn._running_mean)))
+        self.assertTrue(mx.allclose(bn._running_var, mx.ones_like(bn._running_var)))
+        y = bn(x)
+        expected_y = mx.array(
+            [
+                [-0.439520, 1.647328, -0.955515, 1.966031],
+                [-1.726690, -1.449826, -0.234026, -0.723364],
+                [0.938414, -0.349603, -0.354470, -0.175369],
+                [0.305006, 0.234914, -0.393017, -0.459385],
+                [0.922789, -0.082813, 1.937028, -0.607913],
+            ],
+        )
+        expected_mean = mx.array([0.008929, 0.005680, -0.016092, 0.027778])
+        expected_var = mx.array([0.928435, 1.00455, 1.04117, 0.94258])
+        self.assertTrue(x.shape == y.shape)
+        self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))
+        self.assertTrue(mx.allclose(bn._running_mean, expected_mean, atol=1e-5))
+        self.assertTrue(mx.allclose(bn._running_var, expected_var, atol=1e-5))
+
+        # test eval mode
+        bn.eval()
+        y = bn(x)
+        expected_y = mx.array(
+            [
+                [-0.15984, 1.73159, -1.25456, 1.57891],
+                [-0.872193, -1.4281, -0.414439, -0.228678],
+                [0.602743, -0.30566, -0.554687, 0.139639],
+                [0.252199, 0.29066, -0.599572, -0.0512532],
+                [0.594096, -0.0334829, 2.11359, -0.151081],
+            ]
+        )
+
+        self.assertTrue(x.shape == y.shape)
+        self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))
+
+        # test_no_affine
+        bn = nn.BatchNorm(num_features=4, affine=False)
+        y = bn(x)
+        expected_y = mx.array(
+            [
+                [-0.439520, 1.647328, -0.955515, 1.966031],
+                [-1.726690, -1.449826, -0.234026, -0.723364],
+                [0.938414, -0.349603, -0.354470, -0.175369],
+                [0.305006, 0.234914, -0.393017, -0.459385],
+                [0.922789, -0.082813, 1.937028, -0.607913],
+            ]
+        )
+        self.assertTrue(x.shape == y.shape)
+        self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))
+
+        # test with 3D input
+        mx.random.seed(42)
+        N = 2
+        L = 4
+        C = 5
+        x = mx.random.normal((N, L, C), dtype=mx.float32)
+
+        # Batch norm
+        bn = nn.BatchNorm(num_features=C, affine=True)
+        self.assertTrue(mx.allclose(bn._running_mean, mx.zeros_like(bn._running_mean)))
+        self.assertTrue(mx.allclose(bn._running_var, mx.ones_like(bn._running_var)))
+        y = bn(x)
+        self.assertTrue(x.shape == y.shape)
+        expected_y = mx.array(
+            [
+                [
+                    [-0.335754, 0.342054, 1.02653, 0.628588, -1.63899],
+                    [1.92092, 0.432319, 0.343043, 1.95489, 1.0696],
+                    [-0.853748, 1.3661, 0.868569, 0.0199196, -0.887284],
+                    [0.459206, -0.684822, -0.706354, -0.271531, 0.566341],
+                ],
+                [
+                    [-0.921179, 0.684951, -0.77466, -0.490372, -0.247032],
+                    [1.10839, -2.13179, 0.628924, -1.62639, -0.539708],
+                    [-0.348943, 0.412194, -2.03818, 0.524972, 1.64568],
+                    [-1.02889, -0.421, 0.652127, -0.740079, 0.0313996],
+                ],
+            ]
+        )
+        self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))
+        expected_mean = mx.array(
+            [[[0.00207845, -5.3259e-05, 0.04755, -0.0697296, 0.0236228]]]
+        )
+        expected_var = mx.array([[[0.968415, 1.05322, 0.96913, 0.932305, 0.967224]]])
+        self.assertTrue(mx.allclose(bn._running_mean, expected_mean, atol=1e-5))
+        self.assertTrue(mx.allclose(bn._running_var, expected_var, atol=1e-5))
+
+        x = mx.random.normal((N, L, C, L, C), dtype=mx.float32)
+        with self.assertRaises(ValueError):
+            y = bn(x)
+
+    def test_batch_norm_stats(self):
+        batch_size = 2
+        num_features = 4
+        h = 3
+        w = 3
+        momentum = 0.1
+
+        batch_norm = nn.BatchNorm(num_features)
+
+        batch_norm.train()
+        running_mean = np.array(batch_norm._running_mean)
+        running_var = np.array(batch_norm._running_var)
+
+        data = mx.random.normal((batch_size, num_features))
+
+        normalized_data = batch_norm(data)
+        np_data = np.array(data)
+        means = np.mean(np_data, axis=0)
+        variances = np.var(np_data, axis=0)
+        running_mean = (1 - momentum) * running_mean + momentum * means
+        running_var = (1 - momentum) * running_var + momentum * variances
+        self.assertTrue(np.allclose(batch_norm._running_mean, running_mean, atol=1e-5))
+        self.assertTrue(np.allclose(batch_norm._running_var, running_var, atol=1e-5))
+
+        batch_norm = nn.BatchNorm(num_features)
+
+        batch_norm.train()
+        running_mean = np.array(batch_norm._running_mean)
+        running_var = np.array(batch_norm._running_var)
+        data = mx.random.normal((batch_size, h, w, num_features))
+
+        normalized_data = batch_norm(data)
+        np_data = np.array(data)
+        means = np.mean(np_data, axis=(0, 1, 2))
+        variances = np.var(np_data, axis=(0, 1, 2))
+        running_mean = (1 - momentum) * running_mean + momentum * means
+        running_var = (1 - momentum) * running_var + momentum * variances
+        self.assertTrue(np.allclose(batch_norm._running_mean, running_mean, atol=1e-5))
+        self.assertTrue(np.allclose(batch_norm._running_var, running_var, atol=1e-5))
+
     def test_conv1d(self):
         N = 5
         L = 12