added test cases for batch norm on 3D input & refactored code ^^

This commit is contained in:
m0saan 2023-12-19 10:06:35 +01:00
parent eca773b62c
commit 7ec3cadf98
2 changed files with 57 additions and 8 deletions

View File

@ -253,12 +253,32 @@ class BatchNorm1d(Module):
if self.track_running_stats and self.training: if self.track_running_stats and self.training:
self.running_mean = ( self.running_mean = (
1 - self.momentum 1 - self.momentum
) * self.running_mean + self.momentum * means.squeeze() ) * self.running_mean + self.momentum * means
self.running_var = ( self.running_var = (
1 - self.momentum 1 - self.momentum
) * self.running_var + self.momentum * var.squeeze() ) * self.running_var + self.momentum * var
return means, var return means, var
def _check_and_expand_dims(self, x: mx.array):
"""
Check if the input is a 2D or 3D tensor and expand the weight, bias, running mean, and running variance accordingly.
Args:
x (mx.array): Input tensor.
"""
if x.ndim != 2 and x.ndim != 3:
raise ValueError(f"expected 2D or 3D input (got {x.ndim}D input)")
if x.ndim == 3 and self.weight.ndim != x.ndim:
self.weight = mx.expand_dims(self.weight, [0, 2])
self.bias = mx.expand_dims(self.bias, [0, 2])
if self.track_running_stats:
if x.ndim == 3 and self.running_mean.ndim != x.ndim:
self.running_mean = mx.expand_dims(self.running_mean, [0, 2])
self.running_var = mx.expand_dims(self.running_var, [0, 2])
def __call__(self, x: mx.array): def __call__(self, x: mx.array):
""" """
Forward pass of BatchNorm1d. Forward pass of BatchNorm1d.
@ -270,12 +290,7 @@ class BatchNorm1d(Module):
mx.array: Output tensor. mx.array: Output tensor.
""" """
if x.ndim != 2 and x.ndim != 3: self._check_and_expand_dims(x)
raise ValueError(f"expected 2D or 3D input (got {x.ndim}D input)")
if x.ndim == 3:
self.weight = mx.expand_dims(self.weight, [0, 2])
self.bias = mx.expand_dims(self.bias, [0, 2])
if self.training or not self.track_running_stats: if self.training or not self.track_running_stats:
means, var = self._calc_stats(x) means, var = self._calc_stats(x)

View File

@ -376,6 +376,40 @@ class TestNN(mlx_tests.MLXTestCase):
self.assertTrue(x.shape == y.shape) self.assertTrue(x.shape == y.shape)
self.assertTrue(np.allclose(y, expected_y, atol=1e-5)) self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
# test with 3D input
mx.random.seed(42)
x = mx.random.normal((2, 4, 3), dtype=mx.float32)
# Batch norm
bn = nn.BatchNorm1d(num_features=4, affine=True)
self.assertTrue(mx.allclose(bn.running_mean, mx.zeros_like(bn.running_mean)))
self.assertTrue(mx.allclose(bn.running_var, mx.ones_like(bn.running_var)))
y = bn(x)
self.assertTrue(x.shape == y.shape)
expected_y = mx.array(
[
[
[-0.285056, -0.657241, 0.584881],
[1.079424, 0.795527, 0.163417],
[-0.351929, 0.669030, 1.713490],
[-0.679080, -1.467115, 1.077580],
],
[
[-0.091968, -1.362007, 1.811391],
[-1.654407, -1.017945, 0.633983],
[-1.309168, 0.148356, -0.869779],
[-0.742132, 1.037774, 0.772974],
],
]
)
self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
expected_mean = mx.array(
[[[0.0362097], [0.0360611], [0.0166926], [-0.0111884]]]
)
expected_var = mx.array([[[1.07218], [0.992639], [1.01724], [1.16217]]])
self.assertTrue(np.allclose(bn.running_mean, expected_mean, atol=1e-5))
self.assertTrue(np.allclose(bn.running_var, expected_var, atol=1e-5))
def test_conv1d(self): def test_conv1d(self):
N = 5 N = 5
L = 12 L = 12