mirror of https://github.com/ml-explore/mlx.git
update batch norm implementation -> fixed some bugs and added support for 3D inputs
This commit is contained in:
parent e9fd1cf02d
commit c3c2fcf41d
@@ -184,12 +184,14 @@ class GroupNorm(Module):
 # Copyright © 2023 Apple Inc.
 
-import mlx.core as mx
-from mlx.nn.layers.base import Module
 from typing import Tuple
+
+import mlx.core as mx
+from mlx.nn.layers.base import Module
 
 
 class BatchNorm1d(Module):
-    r"""Applies Batch Normalization [1] to the inputs.
+    r"""Applies Batch Normalization over a 2D or 3D input.
 
     Computes
@@ -209,8 +211,7 @@ class BatchNorm1d(Module):
         affine (bool, optional): If True, learn an affine transform to apply after the normalization. Default is True.
 
     Examples:
-        >>> import mlx.core as mx
-        >>> import mlx.nn as nn
+        -> TODO: Add examples.
     """
 
     def __init__(
@@ -219,21 +220,25 @@ class BatchNorm1d(Module):
         eps: float = 1e-5,
         momentum: float = 0.1,
         affine: bool = True,
+        track_running_stats: bool = True,
     ):
         super().__init__()
-        if affine:
-            self.bias = mx.zeros((num_features,))
-            self.weight = mx.ones((num_features,))
 
         self.num_features = num_features
         self.eps = eps
-        self.momentum = mx.array([momentum])
+        self.momentum = momentum
+        self.affine = affine
+        self.track_running_stats = track_running_stats
+
+        if self.affine:
+            self.weight = mx.ones((num_features,))
+            self.bias = mx.zeros((num_features,))
 
-        self.running_mean = mx.zeros((num_features,))
-        self.running_var = mx.ones((num_features,))
-        print(self.running_mean.shape)
+        if self.track_running_stats:
+            self.running_mean = mx.zeros((num_features,))
+            self.running_var = mx.ones((num_features,))
 
     def _extra_repr(self):
-        return f"num_features={self.num_features}, eps={self.eps}, momentum={self.momentum}, affine={'weight' in self}"
+        return f"{self.num_features}, eps={self.eps}, momentum={self.momentum}, affine={'weight' in self}, track_running_stats={self.track_running_stats}"
 
     def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]:
         """
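For context: the constructor now stores `momentum` as a plain Python float instead of wrapping it in `mx.array([momentum])`, allocates `weight`/`bias` only when `affine` is set, and allocates the running statistics only when `track_running_stats` is set (the stray debug `print` is also removed). A minimal construction sketch, assuming the layer is exported as `mlx.nn.BatchNorm1d`:

```python
import mlx.nn as nn

# Hedged sketch of the constructor added in this commit.
bn = nn.BatchNorm1d(num_features=4, eps=1e-5, momentum=0.1,
                    affine=True, track_running_stats=True)

print(bn.weight.shape)        # (4,) -- allocated because affine=True
print(bn.running_mean.shape)  # (4,) -- allocated because track_running_stats=True
```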
@@ -245,10 +250,21 @@ class BatchNorm1d(Module):
         Returns:
             tuple: Tuple containing mean and variance.
         """
 
-        means = mx.mean(x, axis=0, keepdims=True)
-        var = mx.var(x, axis=0, keepdims=True)
-        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * means.squeeze()
-        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var.squeeze()
+        if len(x.shape) == 2:
+            means = mx.mean(x, axis=0, keepdims=True)
+            var = mx.var(x, axis=0, keepdims=True)
+        else:
+            means = mx.mean(x, axis=(0, 2), keepdims=True)
+            var = mx.var(x, axis=(0, 2), keepdims=True)
+
+        if self.track_running_stats:
+            self.running_mean = (
+                1 - self.momentum
+            ) * self.running_mean + self.momentum * means.squeeze()
+            self.running_var = (
+                1 - self.momentum
+            ) * self.running_var + self.momentum * var.squeeze()
         return means, var
 
     def __call__(self, x: mx.array):
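The running statistics above follow the standard exponential-moving-average update, `new = (1 - momentum) * old + momentum * batch_stat`, now gated on `track_running_stats`. A standalone numeric sketch of the 2D path with made-up inputs (the same arithmetic, not the module's code):

```python
import mlx.core as mx

momentum = 0.1
x = mx.array([[1.0, 2.0], [3.0, 4.0]])  # (batch=2, num_features=2)

running_mean = mx.zeros((2,))
running_var = mx.ones((2,))

means = mx.mean(x, axis=0, keepdims=True)  # per-feature batch means: [[2.0, 3.0]]
var = mx.var(x, axis=0, keepdims=True)     # per-feature batch variances: [[1.0, 1.0]]

running_mean = (1 - momentum) * running_mean + momentum * means.squeeze()
running_var = (1 - momentum) * running_var + momentum * var.squeeze()
print(running_mean)  # [0.2, 0.3]
print(running_var)   # [1.0, 1.0], i.e. 0.9 * 1.0 + 0.1 * 1.0 per feature
```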
@@ -262,7 +278,14 @@ class BatchNorm1d(Module):
             mx.array: Output tensor.
         """
 
-        if self.training:
+        if x.ndim != 2 and x.ndim != 3:
+            raise ValueError(f"expected 2D or 3D input (got {x.ndim}D input)")
+
+        if x.ndim == 3:
+            self.weight = mx.expand_dims(self.weight, [0, 2])
+            self.bias = mx.expand_dims(self.bias, [0, 2])
+
+        if self.training or not self.track_running_stats:
             means, var = self._calc_stats(x)
         else:
             means, var = self.running_mean, self.running_var
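The docstring's examples are still a TODO in this commit; below is an illustrative usage sketch for the updated layer, assuming it is exported as `mlx.nn.BatchNorm1d` and that `Module.eval()` toggles `self.training` as in `mlx.nn`:

```python
import mlx.core as mx
import mlx.nn as nn

bn = nn.BatchNorm1d(num_features=4)

# 2D input (batch, num_features): batch statistics over axis 0.
y2d = bn(mx.random.normal((8, 4)))

# 3D input (batch, channels, length): batch statistics over axes (0, 2),
# the new case supported by this commit.
y3d = bn(mx.random.normal((8, 4, 16)))

# In evaluation mode the tracked running statistics are used instead
# of the current batch statistics.
bn.eval()
y_eval = bn(mx.random.normal((8, 4)))
```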