mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Update batchnorm to have the running stats in parameters (#305)
This commit is contained in:

committed by
GitHub

parent
040c3bafab
commit
d29770eeaa
@@ -326,8 +326,8 @@ class TestNN(mlx_tests.MLXTestCase):
|
||||
|
||||
# Batch norm
|
||||
bn = nn.BatchNorm(num_features=4, affine=True)
|
||||
self.assertTrue(mx.allclose(bn._running_mean, mx.zeros_like(bn._running_mean)))
|
||||
self.assertTrue(mx.allclose(bn._running_var, mx.ones_like(bn._running_var)))
|
||||
self.assertTrue(mx.allclose(bn.running_mean, mx.zeros_like(bn.running_mean)))
|
||||
self.assertTrue(mx.allclose(bn.running_var, mx.ones_like(bn.running_var)))
|
||||
y = bn(x)
|
||||
expected_y = mx.array(
|
||||
[
|
||||
@@ -342,8 +342,8 @@ class TestNN(mlx_tests.MLXTestCase):
|
||||
expected_var = mx.array([0.928435, 1.00455, 1.04117, 0.94258])
|
||||
self.assertTrue(x.shape == y.shape)
|
||||
self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))
|
||||
self.assertTrue(mx.allclose(bn._running_mean, expected_mean, atol=1e-5))
|
||||
self.assertTrue(mx.allclose(bn._running_var, expected_var, atol=1e-5))
|
||||
self.assertTrue(mx.allclose(bn.running_mean, expected_mean, atol=1e-5))
|
||||
self.assertTrue(mx.allclose(bn.running_var, expected_var, atol=1e-5))
|
||||
|
||||
# test eval mode
|
||||
bn.eval()
|
||||
@@ -385,8 +385,8 @@ class TestNN(mlx_tests.MLXTestCase):
|
||||
|
||||
# Batch norm
|
||||
bn = nn.BatchNorm(num_features=C, affine=True)
|
||||
self.assertTrue(mx.allclose(bn._running_mean, mx.zeros_like(bn._running_mean)))
|
||||
self.assertTrue(mx.allclose(bn._running_var, mx.ones_like(bn._running_var)))
|
||||
self.assertTrue(mx.allclose(bn.running_mean, mx.zeros_like(bn.running_mean)))
|
||||
self.assertTrue(mx.allclose(bn.running_var, mx.ones_like(bn.running_var)))
|
||||
y = bn(x)
|
||||
self.assertTrue(x.shape == y.shape)
|
||||
expected_y = mx.array(
|
||||
@@ -410,13 +410,33 @@ class TestNN(mlx_tests.MLXTestCase):
|
||||
[[[0.00207845, -5.3259e-05, 0.04755, -0.0697296, 0.0236228]]]
|
||||
)
|
||||
expected_var = mx.array([[[0.968415, 1.05322, 0.96913, 0.932305, 0.967224]]])
|
||||
self.assertTrue(mx.allclose(bn._running_mean, expected_mean, atol=1e-5))
|
||||
self.assertTrue(mx.allclose(bn._running_var, expected_var, atol=1e-5))
|
||||
self.assertTrue(mx.allclose(bn.running_mean, expected_mean, atol=1e-5))
|
||||
self.assertTrue(mx.allclose(bn.running_var, expected_var, atol=1e-5))
|
||||
|
||||
x = mx.random.normal((N, L, C, L, C), dtype=mx.float32)
|
||||
with self.assertRaises(ValueError):
|
||||
y = bn(x)
|
||||
|
||||
# Check that the running stats are in the param dictionary
|
||||
bn_parameters = bn.parameters()
|
||||
self.assertIn("running_mean", bn_parameters)
|
||||
self.assertIn("running_var", bn_parameters)
|
||||
self.assertIn("weight", bn_parameters)
|
||||
self.assertIn("bias", bn_parameters)
|
||||
|
||||
bn_trainable = bn.trainable_parameters()
|
||||
self.assertNotIn("running_mean", bn_trainable)
|
||||
self.assertNotIn("running_var", bn_trainable)
|
||||
self.assertIn("weight", bn_trainable)
|
||||
self.assertIn("bias", bn_trainable)
|
||||
|
||||
bn.unfreeze()
|
||||
bn_trainable = bn.trainable_parameters()
|
||||
self.assertNotIn("running_mean", bn_trainable)
|
||||
self.assertNotIn("running_var", bn_trainable)
|
||||
self.assertIn("weight", bn_trainable)
|
||||
self.assertIn("bias", bn_trainable)
|
||||
|
||||
def test_batch_norm_stats(self):
|
||||
batch_size = 2
|
||||
num_features = 4
|
||||
@@ -427,8 +447,8 @@ class TestNN(mlx_tests.MLXTestCase):
|
||||
batch_norm = nn.BatchNorm(num_features)
|
||||
|
||||
batch_norm.train()
|
||||
running_mean = np.array(batch_norm._running_mean)
|
||||
running_var = np.array(batch_norm._running_var)
|
||||
running_mean = np.array(batch_norm.running_mean)
|
||||
running_var = np.array(batch_norm.running_var)
|
||||
|
||||
data = mx.random.normal((batch_size, num_features))
|
||||
|
||||
@@ -438,14 +458,14 @@ class TestNN(mlx_tests.MLXTestCase):
|
||||
variances = np.var(np_data, axis=0)
|
||||
running_mean = (1 - momentum) * running_mean + momentum * means
|
||||
running_var = (1 - momentum) * running_var + momentum * variances
|
||||
self.assertTrue(np.allclose(batch_norm._running_mean, running_mean, atol=1e-5))
|
||||
self.assertTrue(np.allclose(batch_norm._running_var, running_var, atol=1e-5))
|
||||
self.assertTrue(np.allclose(batch_norm.running_mean, running_mean, atol=1e-5))
|
||||
self.assertTrue(np.allclose(batch_norm.running_var, running_var, atol=1e-5))
|
||||
|
||||
batch_norm = nn.BatchNorm(num_features)
|
||||
|
||||
batch_norm.train()
|
||||
running_mean = np.array(batch_norm._running_mean)
|
||||
running_var = np.array(batch_norm._running_var)
|
||||
running_mean = np.array(batch_norm.running_mean)
|
||||
running_var = np.array(batch_norm.running_var)
|
||||
data = mx.random.normal((batch_size, h, w, num_features))
|
||||
|
||||
normalized_data = batch_norm(data)
|
||||
@@ -454,8 +474,8 @@ class TestNN(mlx_tests.MLXTestCase):
|
||||
variances = np.var(np_data, axis=(0, 1, 2))
|
||||
running_mean = (1 - momentum) * running_mean + momentum * means
|
||||
running_var = (1 - momentum) * running_var + momentum * variances
|
||||
self.assertTrue(np.allclose(batch_norm._running_mean, running_mean, atol=1e-5))
|
||||
self.assertTrue(np.allclose(batch_norm._running_var, running_var, atol=1e-5))
|
||||
self.assertTrue(np.allclose(batch_norm.running_mean, running_mean, atol=1e-5))
|
||||
self.assertTrue(np.allclose(batch_norm.running_var, running_var, atol=1e-5))
|
||||
|
||||
def test_conv1d(self):
|
||||
N = 5
|
||||
|
Reference in New Issue
Block a user