updated BN implementation to be more generic ^^
This commit is contained in:
parent 7b0f8bda9c
commit 82ca771e69
@@ -182,7 +182,7 @@ class GroupNorm(Module):
         return (self.weight * x + self.bias) if "weight" in self else x
 
 
-class BatchNorm1d(Module):
+class BatchNorm(Module):
     r"""Applies Batch Normalization over a 2D or 3D input.
 
     Computes
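The docstring's "Computes" line is cut off by the hunk boundary; the formula it introduces is the standard batch-normalization transform (reconstructed here for reference, not copied from the file):

y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot \gamma + \beta

where \gamma and \beta are the learnable weight and bias when affine=True.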
@@ -221,11 +221,13 @@ class BatchNorm1d(Module):
         track_running_stats: bool = True,
     ):
         super().__init__()
 
         self.num_features = num_features
         self.eps = eps
         self.momentum = momentum
         self.affine = affine
         self.track_running_stats = track_running_stats
+        self.dims_expanded = False
 
         if self.affine:
             self.weight = mx.ones((num_features,))
@@ -238,6 +240,37 @@ class BatchNorm1d(Module):
     def _extra_repr(self):
         return f"{self.num_features}, eps={self.eps}, momentum={self.momentum}, affine={'weight' in self}, track_running_stats={self.track_running_stats}"
 
+    def _check_and_expand_dims(self, x: mx.array):
+        """
+        Check if the input is a 2D, 3D, or 4D tensor and expand the weight, bias, running mean, and running variance accordingly.
+
+        Args:
+            x (mx.array): Input tensor.
+        """
+
+        num_dims = len(x.shape)
+        dims_dict = {
+            2: ((1, self.num_features), (0,)),
+            3: ((1, self.num_features, 1), (0, 2)),
+            4: ((1, self.num_features, 1, 1), (0, 2, 3)),
+        }
+
+        if num_dims not in dims_dict:
+            raise ValueError(f"expected num_dims to be 2, 3, or 4 (got {num_dims})")
+
+        shape, self.reduction_axes = dims_dict[num_dims]
+
+        if self.affine:
+            self.weight = mx.expand_dims(self.weight, self.reduction_axes)
+            self.bias = mx.expand_dims(self.bias, self.reduction_axes)
+
+        if self.track_running_stats:
+            self.running_mean = mx.expand_dims(self.running_mean, self.reduction_axes)
+            self.running_var = mx.expand_dims(self.running_var, self.reduction_axes)
+
+        self.dims_expanded = True
+
     def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]:
         """
         Calculate the mean and variance of the input tensor.
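For reference, a minimal standalone sketch (illustration only, not part of the commit) of what the rank-to-axes table in the new _check_and_expand_dims does, assuming the feature axis sits at position 1 (NC, NCL, NCHW layouts):

import mlx.core as mx

num_features = 4
dims_dict = {
    2: ((1, num_features), (0,)),
    3: ((1, num_features, 1), (0, 2)),
    4: ((1, num_features, 1, 1), (0, 2, 3)),
}

x = mx.random.normal((2, num_features, 3))       # 3D input: (N, C, L)
shape, reduction_axes = dims_dict[x.ndim]        # ((1, 4, 1), (0, 2))

# Per-feature statistics, kept broadcastable against x via keepdims.
means = mx.mean(x, axis=reduction_axes, keepdims=True)
var = mx.var(x, axis=reduction_axes, keepdims=True)
print(means.shape, var.shape)                    # (1, 4, 1) (1, 4, 1)

These are the same two values that _calc_stats later consumes through self.reduction_axes.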
@@ -249,12 +282,8 @@ class BatchNorm1d(Module):
             tuple: Tuple containing mean and variance.
         """
 
-        if len(x.shape) == 2:
-            means = mx.mean(x, axis=0, keepdims=True)
-            var = mx.var(x, axis=0, keepdims=True)
-        else:
-            means = mx.mean(x, axis=(0, 2), keepdims=True)
-            var = mx.var(x, axis=(0, 2), keepdims=True)
+        means = mx.mean(x, axis=self.reduction_axes, keepdims=True)
+        var = mx.var(x, axis=self.reduction_axes, keepdims=True)
 
         if self.track_running_stats and self.training:
             self.running_mean = (
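The running-statistics update that this hunk leaves in place (the `(1 - momentum) * ... + self.momentum * var` lines partially visible above) is a standard exponential moving average. A standalone illustration, assuming momentum=0.1 and a 3D input:

import mlx.core as mx

momentum = 0.1
running_mean = mx.zeros((1, 4, 1))
running_var = mx.ones((1, 4, 1))

x = mx.random.normal((2, 4, 3))
means = mx.mean(x, axis=(0, 2), keepdims=True)
var = mx.var(x, axis=(0, 2), keepdims=True)

# One EMA step, as performed in training when track_running_stats is True.
running_mean = (1 - momentum) * running_mean + momentum * means
running_var = (1 - momentum) * running_var + momentum * var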
@@ -265,29 +294,9 @@ class BatchNorm1d(Module):
         ) * self.running_var + self.momentum * var
         return means, var
 
-    def _check_and_expand_dims(self, x: mx.array):
-        """
-        Check if the input is a 2D or 3D tensor and expand the weight, bias, running mean, and running variance accordingly.
-
-        Args:
-            x (mx.array): Input tensor.
-        """
-
-        if x.ndim != 2 and x.ndim != 3:
-            raise ValueError(f"expected 2D or 3D input (got {x.ndim}D input)")
-
-        if x.ndim == 3 and self.weight.ndim != x.ndim:
-            self.weight = mx.expand_dims(self.weight, [0, 2])
-            self.bias = mx.expand_dims(self.bias, [0, 2])
-
-        if self.track_running_stats:
-            if x.ndim == 3 and self.running_mean.ndim != x.ndim:
-                self.running_mean = mx.expand_dims(self.running_mean, [0, 2])
-                self.running_var = mx.expand_dims(self.running_var, [0, 2])
-
-    def __call__(self, x: mx.array):
-        """
-        Forward pass of BatchNorm1d.
+    def __call__(self, x: mx.array) -> mx.array:
+        """
+        Forward pass of BatchNorm.
 
         Args:
             x (mx.array): Input tensor.
@@ -296,6 +305,7 @@ class BatchNorm1d(Module):
             mx.array: Output tensor.
         """
 
-        self._check_and_expand_dims(x)
+        if not self.dims_expanded:
+            self._check_and_expand_dims(x)
 
         if self.training or not self.track_running_stats:
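Putting the pieces together, a rough sketch of the forward pass after this change. Only the `if not self.dims_expanded:` guard is shown in the hunk; the normalization step below is the usual (x - mean) / sqrt(var + eps) form and is written here as an assumption for illustration, not copied from the file:

import mlx.core as mx

def batch_norm_forward(bn, x):
    # One-time lazy setup: pick reduction axes and reshape the parameters
    # so they broadcast against x (hypothetical free-function rewrite of __call__).
    if not bn.dims_expanded:
        bn._check_and_expand_dims(x)

    if bn.training or not bn.track_running_stats:
        means, var = bn._calc_stats(x)           # batch stats (+ EMA update)
    else:
        means, var = bn.running_mean, bn.running_var

    x = (x - means) * mx.rsqrt(var + bn.eps)
    return bn.weight * x + bn.bias if "weight" in bn else x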
@@ -325,7 +325,7 @@ class TestNN(mlx_tests.MLXTestCase):
         x = mx.random.normal((5, 4), dtype=mx.float32)
 
         # Batch norm
-        bn = nn.BatchNorm1d(num_features=4, affine=True)
+        bn = nn.BatchNorm(num_features=4, affine=True)
         self.assertTrue(mx.allclose(bn.running_mean, mx.zeros_like(bn.running_mean)))
         self.assertTrue(mx.allclose(bn.running_var, mx.ones_like(bn.running_var)))
         y = bn(x)
@@ -362,7 +362,7 @@ class TestNN(mlx_tests.MLXTestCase):
         self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
 
         # test_no_affine
-        bn = nn.BatchNorm1d(num_features=4, affine=False)
+        bn = nn.BatchNorm(num_features=4, affine=False)
         y = bn(x)
         expected_y = mx.array(
             [
@@ -381,7 +381,7 @@ class TestNN(mlx_tests.MLXTestCase):
         x = mx.random.normal((2, 4, 3), dtype=mx.float32)
 
         # Batch norm
-        bn = nn.BatchNorm1d(num_features=4, affine=True)
+        bn = nn.BatchNorm(num_features=4, affine=True)
         self.assertTrue(mx.allclose(bn.running_mean, mx.zeros_like(bn.running_mean)))
         self.assertTrue(mx.allclose(bn.running_var, mx.ones_like(bn.running_var)))
         y = bn(x)
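A usage sketch mirroring the tests above (shapes and arguments taken from the tests; note a fresh layer per input rank, since the parameters are expanded in place on the first call):

import mlx.core as mx
import mlx.nn as nn

# 2D input: (N, C)
bn2d = nn.BatchNorm(num_features=4, affine=True)
x2d = mx.random.normal((5, 4), dtype=mx.float32)
y2d = bn2d(x2d)        # training mode: batch stats, running stats updated

# 3D input: (N, C, L)
bn3d = nn.BatchNorm(num_features=4, affine=True)
x3d = mx.random.normal((2, 4, 3), dtype=mx.float32)
y3d = bn3d(x3d)

bn3d.eval()            # switch to the accumulated running statistics
y_eval = bn3d(x3d)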