Mirror of https://github.com/ml-explore/mlx.git (synced 2025-08-21 12:06:42 +08:00)
added test cases for batch norm on 3D input & refactored code ^^
This commit is contained in:
parent eca773b62c
commit 7ec3cadf98
@@ -253,12 +253,32 @@ class BatchNorm1d(Module):
         if self.track_running_stats and self.training:
             self.running_mean = (
                 1 - self.momentum
-            ) * self.running_mean + self.momentum * means.squeeze()
+            ) * self.running_mean + self.momentum * means
             self.running_var = (
                 1 - self.momentum
-            ) * self.running_var + self.momentum * var.squeeze()
+            ) * self.running_var + self.momentum * var
         return means, var
 
+    def _check_and_expand_dims(self, x: mx.array):
+        """
+        Check if the input is a 2D or 3D tensor and expand the weight, bias, running mean, and running variance accordingly.
+
+        Args:
+            x (mx.array): Input tensor.
+        """
+
+        if x.ndim != 2 and x.ndim != 3:
+            raise ValueError(f"expected 2D or 3D input (got {x.ndim}D input)")
+
+        if x.ndim == 3 and self.weight.ndim != x.ndim:
+            self.weight = mx.expand_dims(self.weight, [0, 2])
+            self.bias = mx.expand_dims(self.bias, [0, 2])
+
+        if self.track_running_stats:
+            if x.ndim == 3 and self.running_mean.ndim != x.ndim:
+                self.running_mean = mx.expand_dims(self.running_mean, [0, 2])
+                self.running_var = mx.expand_dims(self.running_var, [0, 2])
+
     def __call__(self, x: mx.array):
         """
         Forward pass of BatchNorm1d.
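The .squeeze() calls are dropped because the helper added in this hunk keeps running_mean and running_var in the same expanded shape as the batch statistics, so the moving-average update broadcasts directly. A minimal sketch of that update, outside the commit and with illustrative values:

    import mlx.core as mx

    momentum = 0.1  # illustrative value, not taken from the commit
    # running mean in the expanded (1, C, 1) shape set up by _check_and_expand_dims
    running_mean = mx.zeros((1, 4, 1))
    batch_means = mx.array([[[0.2], [0.1], [-0.3], [0.05]]])  # shape (1, 4, 1)
    # the exponential moving average from the hunk above
    running_mean = (1 - momentum) * running_mean + momentum * batch_means
    print(running_mean)  # first update: momentum * batch_means, since it started at zero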
@@ -270,12 +290,7 @@ class BatchNorm1d(Module):
             mx.array: Output tensor.
         """
 
-        if x.ndim != 2 and x.ndim != 3:
-            raise ValueError(f"expected 2D or 3D input (got {x.ndim}D input)")
-
-        if x.ndim == 3:
-            self.weight = mx.expand_dims(self.weight, [0, 2])
-            self.bias = mx.expand_dims(self.bias, [0, 2])
+        self._check_and_expand_dims(x)
 
         if self.training or not self.track_running_stats:
             means, var = self._calc_stats(x)
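The inline validation and parameter expansion in __call__ are consolidated into the new _check_and_expand_dims helper. A hedged sketch, not part of the commit, of the broadcasting this expansion enables: a per-feature parameter of shape (C,) becomes (1, C, 1) so it lines up with an (N, C, L) batch, matching the 3D test input below.

    import mlx.core as mx

    N, C, L = 2, 4, 3
    weight = mx.ones((C,))                      # per-feature scale, shape (4,)
    weight_3d = mx.expand_dims(weight, [0, 2])  # shape (1, 4, 1), as the helper does
    x = mx.random.normal((N, C, L))
    y = x * weight_3d                           # broadcasts over batch and length
    print(y.shape)                              # (2, 4, 3)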
@@ -376,6 +376,40 @@ class TestNN(mlx_tests.MLXTestCase):
         self.assertTrue(x.shape == y.shape)
         self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
 
+        # test with 3D input
+        mx.random.seed(42)
+        x = mx.random.normal((2, 4, 3), dtype=mx.float32)
+
+        # Batch norm
+        bn = nn.BatchNorm1d(num_features=4, affine=True)
+        self.assertTrue(mx.allclose(bn.running_mean, mx.zeros_like(bn.running_mean)))
+        self.assertTrue(mx.allclose(bn.running_var, mx.ones_like(bn.running_var)))
+        y = bn(x)
+        self.assertTrue(x.shape == y.shape)
+        expected_y = mx.array(
+            [
+                [
+                    [-0.285056, -0.657241, 0.584881],
+                    [1.079424, 0.795527, 0.163417],
+                    [-0.351929, 0.669030, 1.713490],
+                    [-0.679080, -1.467115, 1.077580],
+                ],
+                [
+                    [-0.091968, -1.362007, 1.811391],
+                    [-1.654407, -1.017945, 0.633983],
+                    [-1.309168, 0.148356, -0.869779],
+                    [-0.742132, 1.037774, 0.772974],
+                ],
+            ]
+        )
+        self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
+        expected_mean = mx.array(
+            [[[0.0362097], [0.0360611], [0.0166926], [-0.0111884]]]
+        )
+        expected_var = mx.array([[[1.07218], [0.992639], [1.01724], [1.16217]]])
+        self.assertTrue(np.allclose(bn.running_mean, expected_mean, atol=1e-5))
+        self.assertTrue(np.allclose(bn.running_var, expected_var, atol=1e-5))
+
     def test_conv1d(self):
         N = 5
         L = 12
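A hedged sanity check on the expected running statistics above, assuming a default momentum of 0.1 (an assumption; this diff does not show it): the first two assertions confirm the stats start at zero mean and unit variance, so inverting the update from the first hunk recovers the per-feature batch statistics of the (2, 4, 3) input.

    import numpy as np

    momentum = 0.1  # assumed default, not shown in this diff
    expected_mean = np.array([0.0362097, 0.0360611, 0.0166926, -0.0111884])
    expected_var = np.array([1.07218, 0.992639, 1.01724, 1.16217])

    # invert one step of: running = (1 - momentum) * running + momentum * batch
    batch_mean = expected_mean / momentum                    # running mean started at 0
    batch_var = (expected_var - (1 - momentum)) / momentum   # running var started at 1
    print(batch_mean)  # ~[ 0.362,  0.361,  0.167, -0.112]
    print(batch_var)   # ~[ 1.722,  0.926,  1.172,  2.622]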