mirror of https://github.com/ml-explore/mlx.git
synced 2025-08-21 11:18:06 +08:00

implemented batchnorm layer

This commit is contained in:
parent 22fee5a383
commit 2b617b63bd
@@ -1,5 +1,7 @@
 # Copyright © 2023 Apple Inc.
 
+from typing import Tuple
+
 import mlx.core as mx
 from mlx.nn.layers.base import Module
 
@@ -97,7 +99,7 @@ class GroupNorm(Module):
     where :math:`\gamma` and :math:`\beta` are learned per feature dimension
     parameters initialized at 1 and 0 respectively. However, the mean and
     variance are computed over the spatial dimensions and each group of
-    features. In particular, the input is split into num_groups accross the
+    features. In particular, the input is split into num_groups across the
     feature dimension.
 
     The feature dimension is assumed to be the last dimension and the dimensions
@@ -178,3 +180,95 @@ class GroupNorm(Module):
         )
         x = group_norm(x)
         return (self.weight * x + self.bias) if "weight" in self else x
+
+
+class BatchNorm1d(Module):
+    r"""Applies Batch Normalization [1] to the inputs.
+
+    Computes
+
+    .. math::
+
+        y = \frac{x - E[x]}{\sqrt{Var[x] + \epsilon}} \gamma + \beta,
+
+    where :math:`\gamma` and :math:`\beta` are learned per feature dimension
+    parameters initialized at 1 and 0 respectively.
+
+    [1]: https://arxiv.org/abs/1502.03167
+
+    Args:
+        num_features (int): The feature dimension of the input to normalize over.
+        eps (float, optional): A small additive constant for numerical stability. Default is 1e-5.
+        momentum (float, optional): The momentum for updating the running mean and variance. Default is 0.1.
+        affine (bool, optional): If True, learn an affine transform to apply after the normalization. Default is True.
+
+    Examples:
+        >>> import mlx.core as mx
+        >>> import mlx.nn as nn
+
+        >>> # With Learnable Parameters
+        >>> m = nn.BatchNorm1d(100)
+        >>> # Without Learnable Parameters
+        >>> m = nn.BatchNorm1d(4, affine=False)
+        >>> input = mx.random.normal((20, 4))
+        >>> output = m(input)
+    """
+
+    def __init__(
+        self,
+        num_features: int,
+        eps: float = 1e-5,
+        momentum: float = 0.1,
+        affine: bool = True,
+    ):
+        super().__init__()
+        if affine:
+            self.bias = mx.zeros((num_features,))
+            self.weight = mx.ones((num_features,))
+
+        self.num_features = num_features
+        self.eps = eps
+        self.momentum = momentum
+        self.running_mean = mx.zeros((num_features,))
+        self.running_var = mx.ones((num_features,))
+
+    def _extra_repr(self):
+        return f"num_features={self.num_features}, eps={self.eps}, momentum={self.momentum}, affine={'weight' in self}"
+
+    def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]:
+        """
+        Calculate the mean and variance of the input tensor.
+
+        Args:
+            x (mx.array): Input tensor.
+
+        Returns:
+            tuple: Tuple containing mean and variance.
+        """
+        means = mx.mean(x, axis=0, keepdims=True)
+        var = mx.var(x, axis=0, keepdims=True)
+        self.running_mean = (
+            self.momentum * self.running_mean + (1 - self.momentum) * means
+        )
+        self.running_var = self.momentum * self.running_var + (1 - self.momentum) * var
+        return means, var
+
+    def __call__(self, x: mx.array):
+        """
+        Forward pass of BatchNorm1d.
+
+        Args:
+            x (mx.array): Input tensor.
+
+        Returns:
+            mx.array: Output tensor.
+        """
+        if x.ndim != 2:
+            raise ValueError("BatchNorm1d only supports 2D inputs")
+
+        means, var = self.running_mean, self.running_var
+        if self.training:
+            means, var = self._calc_stats(x)
+        x = (x - means) * mx.rsqrt(var + self.eps)
+        return (self.weight * x + self.bias) if "weight" in self else x
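A quick sanity check of the arithmetic in `__call__` above: the layer normalizes each feature column to zero mean and unit variance via `mx.rsqrt(var + eps)`. The following standalone sketch (not part of the commit) reproduces that computation by hand with the same mx ops the layer uses:

import mlx.core as mx

# A tiny batch: 3 samples, 2 features.
x = mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
eps = 1e-5

# Per-feature statistics over the batch axis, as in _calc_stats.
means = mx.mean(x, axis=0, keepdims=True)  # [[3.0, 4.0]]
var = mx.var(x, axis=0, keepdims=True)     # [[8/3, 8/3]]

# The normalization step from __call__ (affine transform omitted).
y = (x - means) * mx.rsqrt(var + eps)

print(mx.mean(y, axis=0))  # ~[0.0, 0.0]
print(mx.var(y, axis=0))   # ~[1.0, 1.0]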
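Two behavioral notes on the implementation. First, `__call__` computes fresh batch statistics (and updates the running statistics) only when `self.training` is set; otherwise it normalizes with the stored `running_mean` and `running_var`. Second, as written in `_calc_stats`, `momentum` weights the *previous* running value, so with the default `momentum=0.1` each new batch contributes 90% of the updated running statistic. The hedged usage sketch below assumes the new layer is exported as `nn.BatchNorm1d` and that the usual `mlx.nn` `Module.train()` / `Module.eval()` methods toggle the `training` flag:

import mlx.core as mx
import mlx.nn as nn

bn = nn.BatchNorm1d(4)
x = mx.random.normal((20, 4))

bn.train()       # training mode: batch stats are used, running stats updated
y_train = bn(x)

bn.eval()        # inference mode: stored running_mean / running_var are used
y_eval = bn(x)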