mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 12:06:42 +08:00
added test cases for batch norm on 3D input & refactored code ^^
This commit is contained in:
parent
eca773b62c
commit
7ec3cadf98
@ -253,12 +253,32 @@ class BatchNorm1d(Module):
|
|||||||
if self.track_running_stats and self.training:
|
if self.track_running_stats and self.training:
|
||||||
self.running_mean = (
|
self.running_mean = (
|
||||||
1 - self.momentum
|
1 - self.momentum
|
||||||
) * self.running_mean + self.momentum * means.squeeze()
|
) * self.running_mean + self.momentum * means
|
||||||
self.running_var = (
|
self.running_var = (
|
||||||
1 - self.momentum
|
1 - self.momentum
|
||||||
) * self.running_var + self.momentum * var.squeeze()
|
) * self.running_var + self.momentum * var
|
||||||
return means, var
|
return means, var
|
||||||
|
|
||||||
|
def _check_and_expand_dims(self, x: mx.array):
|
||||||
|
"""
|
||||||
|
Check if the input is a 2D or 3D tensor and expand the weight, bias, running mean, and running variance accordingly.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (mx.array): Input tensor.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if x.ndim != 2 and x.ndim != 3:
|
||||||
|
raise ValueError(f"expected 2D or 3D input (got {x.ndim}D input)")
|
||||||
|
|
||||||
|
if x.ndim == 3 and self.weight.ndim != x.ndim:
|
||||||
|
self.weight = mx.expand_dims(self.weight, [0, 2])
|
||||||
|
self.bias = mx.expand_dims(self.bias, [0, 2])
|
||||||
|
|
||||||
|
if self.track_running_stats:
|
||||||
|
if x.ndim == 3 and self.running_mean.ndim != x.ndim:
|
||||||
|
self.running_mean = mx.expand_dims(self.running_mean, [0, 2])
|
||||||
|
self.running_var = mx.expand_dims(self.running_var, [0, 2])
|
||||||
|
|
||||||
def __call__(self, x: mx.array):
|
def __call__(self, x: mx.array):
|
||||||
"""
|
"""
|
||||||
Forward pass of BatchNorm1d.
|
Forward pass of BatchNorm1d.
|
||||||
@ -270,12 +290,7 @@ class BatchNorm1d(Module):
|
|||||||
mx.array: Output tensor.
|
mx.array: Output tensor.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if x.ndim != 2 and x.ndim != 3:
|
self._check_and_expand_dims(x)
|
||||||
raise ValueError(f"expected 2D or 3D input (got {x.ndim}D input)")
|
|
||||||
|
|
||||||
if x.ndim == 3:
|
|
||||||
self.weight = mx.expand_dims(self.weight, [0, 2])
|
|
||||||
self.bias = mx.expand_dims(self.bias, [0, 2])
|
|
||||||
|
|
||||||
if self.training or not self.track_running_stats:
|
if self.training or not self.track_running_stats:
|
||||||
means, var = self._calc_stats(x)
|
means, var = self._calc_stats(x)
|
||||||
|
@ -376,6 +376,40 @@ class TestNN(mlx_tests.MLXTestCase):
|
|||||||
self.assertTrue(x.shape == y.shape)
|
self.assertTrue(x.shape == y.shape)
|
||||||
self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
|
self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
|
||||||
|
|
||||||
|
# test with 3D input
|
||||||
|
mx.random.seed(42)
|
||||||
|
x = mx.random.normal((2, 4, 3), dtype=mx.float32)
|
||||||
|
|
||||||
|
# Batch norm
|
||||||
|
bn = nn.BatchNorm1d(num_features=4, affine=True)
|
||||||
|
self.assertTrue(mx.allclose(bn.running_mean, mx.zeros_like(bn.running_mean)))
|
||||||
|
self.assertTrue(mx.allclose(bn.running_var, mx.ones_like(bn.running_var)))
|
||||||
|
y = bn(x)
|
||||||
|
self.assertTrue(x.shape == y.shape)
|
||||||
|
expected_y = mx.array(
|
||||||
|
[
|
||||||
|
[
|
||||||
|
[-0.285056, -0.657241, 0.584881],
|
||||||
|
[1.079424, 0.795527, 0.163417],
|
||||||
|
[-0.351929, 0.669030, 1.713490],
|
||||||
|
[-0.679080, -1.467115, 1.077580],
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[-0.091968, -1.362007, 1.811391],
|
||||||
|
[-1.654407, -1.017945, 0.633983],
|
||||||
|
[-1.309168, 0.148356, -0.869779],
|
||||||
|
[-0.742132, 1.037774, 0.772974],
|
||||||
|
],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
|
||||||
|
expected_mean = mx.array(
|
||||||
|
[[[0.0362097], [0.0360611], [0.0166926], [-0.0111884]]]
|
||||||
|
)
|
||||||
|
expected_var = mx.array([[[1.07218], [0.992639], [1.01724], [1.16217]]])
|
||||||
|
self.assertTrue(np.allclose(bn.running_mean, expected_mean, atol=1e-5))
|
||||||
|
self.assertTrue(np.allclose(bn.running_var, expected_var, atol=1e-5))
|
||||||
|
|
||||||
def test_conv1d(self):
|
def test_conv1d(self):
|
||||||
N = 5
|
N = 5
|
||||||
L = 12
|
L = 12
|
||||||
|
Loading…
Reference in New Issue
Block a user