updated BN implementation to be more generic ^^

m0saan 2023-12-22 09:58:04 +01:00
parent 7b0f8bda9c
commit 82ca771e69
2 changed files with 47 additions and 37 deletions
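In short: BatchNorm1d becomes a generic BatchNorm that dispatches on the rank of its input. Rather than hard-coding the 2D and 3D cases, the new _check_and_expand_dims looks up the reduction axes (and parameter broadcast shape) per rank, which also covers 4D NCHW-style inputs, and it runs only once per module thanks to the dims_expanded flag. A minimal sketch of the dispatch idea, standalone and with made-up names, not the committed code:

import mlx.core as mx

# Rank -> axes to reduce over when computing per-channel statistics.
# Mirrors the dims_dict added in this commit (C = num_features).
AXES_BY_RANK = {
    2: (0,),       # (N, C)       -> stats over N
    3: (0, 2),     # (N, C, L)    -> stats over N and L
    4: (0, 2, 3),  # (N, C, H, W) -> stats over N, H and W
}

x = mx.random.normal((8, 4, 16))  # a 3D (N, C, L) batch
axes = AXES_BY_RANK[x.ndim]
mean = mx.mean(x, axis=axes, keepdims=True)  # shape (1, 4, 1)
var = mx.var(x, axis=axes, keepdims=True)    # shape (1, 4, 1)
print(mean.shape, var.shape)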

View File

@@ -182,7 +182,7 @@ class GroupNorm(Module):
        return (self.weight * x + self.bias) if "weight" in self else x


-class BatchNorm1d(Module):
+class BatchNorm(Module):
    r"""Applies Batch Normalization over a 2D or 3D input.

    Computes
@@ -221,11 +221,13 @@ class BatchNorm1d(Module):
        track_running_stats: bool = True,
    ):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
+        self.dims_expanded = False

        if self.affine:
            self.weight = mx.ones((num_features,))
@@ -238,6 +240,37 @@ class BatchNorm1d(Module):
    def _extra_repr(self):
        return f"{self.num_features}, eps={self.eps}, momentum={self.momentum}, affine={'weight' in self}, track_running_stats={self.track_running_stats}"

+    def _check_and_expand_dims(self, x: mx.array):
+        """
+        Check if the input is a 2D, 3D, or 4D tensor and expand the weight,
+        bias, running mean, and running variance accordingly.
+
+        Args:
+            x (mx.array): Input tensor.
+        """
+        num_dims = len(x.shape)
+        dims_dict = {
+            2: ((1, self.num_features), (0,)),
+            3: ((1, self.num_features, 1), (0, 2)),
+            4: ((1, self.num_features, 1, 1), (0, 2, 3)),
+        }
+
+        if num_dims not in dims_dict:
+            raise ValueError(f"expected num_dims to be 2, 3, or 4 (got {num_dims})")
+
+        shape, self.reduction_axes = dims_dict[num_dims]
+
+        if self.affine:
+            self.weight = mx.expand_dims(self.weight, self.reduction_axes)
+            self.bias = mx.expand_dims(self.bias, self.reduction_axes)
+
+        if self.track_running_stats:
+            self.running_mean = mx.expand_dims(self.running_mean, self.reduction_axes)
+            self.running_var = mx.expand_dims(self.running_var, self.reduction_axes)
+
+        self.dims_expanded = True
+
    def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]:
        """
        Calculate the mean and variance of the input tensor.
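A note on the mx.expand_dims calls above: MLX's expand_dims accepts a sequence of axes, so the (num_features,) parameters can be broadcast against any supported rank in a single call. A quick illustrative check, with hypothetical values, not part of the commit:

import mlx.core as mx

weight = mx.ones((4,))  # per-channel parameter, C = 4

# Axes taken from dims_dict for each supported input rank:
print(mx.expand_dims(weight, (0,)).shape)       # (1, 4)       for (N, C)
print(mx.expand_dims(weight, (0, 2)).shape)     # (1, 4, 1)    for (N, C, L)
print(mx.expand_dims(weight, (0, 2, 3)).shape)  # (1, 4, 1, 1) for (N, C, H, W)

Incidentally, the shape entry unpacked from dims_dict is not used afterwards; only reduction_axes drives the broadcasting.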
@@ -249,12 +282,8 @@ class BatchNorm1d(Module):
            tuple: Tuple containing mean and variance.
        """

-        if len(x.shape) == 2:
-            means = mx.mean(x, axis=0, keepdims=True)
-            var = mx.var(x, axis=0, keepdims=True)
-        else:
-            means = mx.mean(x, axis=(0, 2), keepdims=True)
-            var = mx.var(x, axis=(0, 2), keepdims=True)
+        means = mx.mean(x, axis=self.reduction_axes, keepdims=True)
+        var = mx.var(x, axis=self.reduction_axes, keepdims=True)

        if self.track_running_stats and self.training:
            self.running_mean = (
@@ -265,29 +294,9 @@ class BatchNorm1d(Module):
            ) * self.running_var + self.momentum * var

        return means, var

-    def _check_and_expand_dims(self, x: mx.array):
-        """
-        Check if the input is a 2D or 3D tensor and expand the weight, bias, running mean, and running variance accordingly.
-
-        Args:
-            x (mx.array): Input tensor.
-        """
-        if x.ndim != 2 and x.ndim != 3:
-            raise ValueError(f"expected 2D or 3D input (got {x.ndim}D input)")
-
-        if x.ndim == 3 and self.weight.ndim != x.ndim:
-            self.weight = mx.expand_dims(self.weight, [0, 2])
-            self.bias = mx.expand_dims(self.bias, [0, 2])
-
-        if self.track_running_stats:
-            if x.ndim == 3 and self.running_mean.ndim != x.ndim:
-                self.running_mean = mx.expand_dims(self.running_mean, [0, 2])
-                self.running_var = mx.expand_dims(self.running_var, [0, 2])
-
-    def __call__(self, x: mx.array):
-        """
-        Forward pass of BatchNorm1d.
+    def __call__(self, x: mx.array) -> mx.array:
+        """
+        Forward pass of BatchNorm.

        Args:
            x (mx.array): Input tensor.
@@ -296,7 +305,8 @@ class BatchNorm1d(Module):
            mx.array: Output tensor.
        """
-        self._check_and_expand_dims(x)
+        if not self.dims_expanded:
+            self._check_and_expand_dims(x)

        if self.training or not self.track_running_stats:
            means, var = self._calc_stats(x)
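The running-statistics update kept in _calc_stats is the usual exponential moving average; only the tail of the assignment appears in the hunks above, so the (1 - momentum) factor on the old statistic is implied. A small standalone sketch of the update rule, assuming those momentum semantics:

import mlx.core as mx

momentum = 0.1
running_mean = mx.zeros((1, 4))  # state after _check_and_expand_dims, 2D case
batch_mean = mx.array([[1.0, 2.0, 3.0, 4.0]])

# EMA update: new = (1 - momentum) * old + momentum * batch statistic
running_mean = (1 - momentum) * running_mean + momentum * batch_mean
print(running_mean)  # [[0.1, 0.2, 0.3, 0.4]]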

View File

@@ -325,7 +325,7 @@ class TestNN(mlx_tests.MLXTestCase):
        x = mx.random.normal((5, 4), dtype=mx.float32)

        # Batch norm
-        bn = nn.BatchNorm1d(num_features=4, affine=True)
+        bn = nn.BatchNorm(num_features=4, affine=True)
        self.assertTrue(mx.allclose(bn.running_mean, mx.zeros_like(bn.running_mean)))
        self.assertTrue(mx.allclose(bn.running_var, mx.ones_like(bn.running_var)))
        y = bn(x)
@@ -362,7 +362,7 @@ class TestNN(mlx_tests.MLXTestCase):
        self.assertTrue(np.allclose(y, expected_y, atol=1e-5))

        # test_no_affine
-        bn = nn.BatchNorm1d(num_features=4, affine=False)
+        bn = nn.BatchNorm(num_features=4, affine=False)
        y = bn(x)
        expected_y = mx.array(
            [
@@ -381,7 +381,7 @@ class TestNN(mlx_tests.MLXTestCase):
        x = mx.random.normal((2, 4, 3), dtype=mx.float32)

        # Batch norm
-        bn = nn.BatchNorm1d(num_features=4, affine=True)
+        bn = nn.BatchNorm(num_features=4, affine=True)
        self.assertTrue(mx.allclose(bn.running_mean, mx.zeros_like(bn.running_mean)))
        self.assertTrue(mx.allclose(bn.running_var, mx.ones_like(bn.running_var)))
        y = bn(x)
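Taken together, the updated tests exercise one class across input ranks. A hedged end-to-end sketch of the intended usage, assuming this branch's mlx.nn exposes BatchNorm as the tests do:

import mlx.core as mx
import mlx.nn as nn

# One class for 2D and 3D inputs. Note that an instance adapts its parameter
# shapes to the first input it sees (dims_expanded is set once), so use a
# separate instance per input rank.
bn2d = nn.BatchNorm(num_features=4, affine=True)
y2d = bn2d(mx.random.normal((5, 4)))     # (N, C)

bn3d = nn.BatchNorm(num_features=4, affine=True)
y3d = bn3d(mx.random.normal((2, 4, 3)))  # (N, C, L)

print(y2d.shape, y3d.shape)  # (5, 4) (2, 4, 3)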