Update batchnorm to have the running stats in parameters (#305)

This commit is contained in:
Angelos Katharopoulos
2023-12-28 14:31:10 -08:00
committed by GitHub
parent 040c3bafab
commit d29770eeaa
2 changed files with 65 additions and 37 deletions

View File

@@ -326,8 +326,8 @@ class TestNN(mlx_tests.MLXTestCase):
# Batch norm
bn = nn.BatchNorm(num_features=4, affine=True)
self.assertTrue(mx.allclose(bn._running_mean, mx.zeros_like(bn._running_mean)))
self.assertTrue(mx.allclose(bn._running_var, mx.ones_like(bn._running_var)))
self.assertTrue(mx.allclose(bn.running_mean, mx.zeros_like(bn.running_mean)))
self.assertTrue(mx.allclose(bn.running_var, mx.ones_like(bn.running_var)))
y = bn(x)
expected_y = mx.array(
[
@@ -342,8 +342,8 @@ class TestNN(mlx_tests.MLXTestCase):
expected_var = mx.array([0.928435, 1.00455, 1.04117, 0.94258])
self.assertTrue(x.shape == y.shape)
self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))
self.assertTrue(mx.allclose(bn._running_mean, expected_mean, atol=1e-5))
self.assertTrue(mx.allclose(bn._running_var, expected_var, atol=1e-5))
self.assertTrue(mx.allclose(bn.running_mean, expected_mean, atol=1e-5))
self.assertTrue(mx.allclose(bn.running_var, expected_var, atol=1e-5))
# test eval mode
bn.eval()
@@ -385,8 +385,8 @@ class TestNN(mlx_tests.MLXTestCase):
# Batch norm
bn = nn.BatchNorm(num_features=C, affine=True)
self.assertTrue(mx.allclose(bn._running_mean, mx.zeros_like(bn._running_mean)))
self.assertTrue(mx.allclose(bn._running_var, mx.ones_like(bn._running_var)))
self.assertTrue(mx.allclose(bn.running_mean, mx.zeros_like(bn.running_mean)))
self.assertTrue(mx.allclose(bn.running_var, mx.ones_like(bn.running_var)))
y = bn(x)
self.assertTrue(x.shape == y.shape)
expected_y = mx.array(
@@ -410,13 +410,33 @@ class TestNN(mlx_tests.MLXTestCase):
[[[0.00207845, -5.3259e-05, 0.04755, -0.0697296, 0.0236228]]]
)
expected_var = mx.array([[[0.968415, 1.05322, 0.96913, 0.932305, 0.967224]]])
self.assertTrue(mx.allclose(bn._running_mean, expected_mean, atol=1e-5))
self.assertTrue(mx.allclose(bn._running_var, expected_var, atol=1e-5))
self.assertTrue(mx.allclose(bn.running_mean, expected_mean, atol=1e-5))
self.assertTrue(mx.allclose(bn.running_var, expected_var, atol=1e-5))
x = mx.random.normal((N, L, C, L, C), dtype=mx.float32)
with self.assertRaises(ValueError):
y = bn(x)
# Check that the running stats are in the param dictionary
bn_parameters = bn.parameters()
self.assertIn("running_mean", bn_parameters)
self.assertIn("running_var", bn_parameters)
self.assertIn("weight", bn_parameters)
self.assertIn("bias", bn_parameters)
bn_trainable = bn.trainable_parameters()
self.assertNotIn("running_mean", bn_trainable)
self.assertNotIn("running_var", bn_trainable)
self.assertIn("weight", bn_trainable)
self.assertIn("bias", bn_trainable)
bn.unfreeze()
bn_trainable = bn.trainable_parameters()
self.assertNotIn("running_mean", bn_trainable)
self.assertNotIn("running_var", bn_trainable)
self.assertIn("weight", bn_trainable)
self.assertIn("bias", bn_trainable)
def test_batch_norm_stats(self):
batch_size = 2
num_features = 4
@@ -427,8 +447,8 @@ class TestNN(mlx_tests.MLXTestCase):
batch_norm = nn.BatchNorm(num_features)
batch_norm.train()
running_mean = np.array(batch_norm._running_mean)
running_var = np.array(batch_norm._running_var)
running_mean = np.array(batch_norm.running_mean)
running_var = np.array(batch_norm.running_var)
data = mx.random.normal((batch_size, num_features))
@@ -438,14 +458,14 @@ class TestNN(mlx_tests.MLXTestCase):
variances = np.var(np_data, axis=0)
running_mean = (1 - momentum) * running_mean + momentum * means
running_var = (1 - momentum) * running_var + momentum * variances
self.assertTrue(np.allclose(batch_norm._running_mean, running_mean, atol=1e-5))
self.assertTrue(np.allclose(batch_norm._running_var, running_var, atol=1e-5))
self.assertTrue(np.allclose(batch_norm.running_mean, running_mean, atol=1e-5))
self.assertTrue(np.allclose(batch_norm.running_var, running_var, atol=1e-5))
batch_norm = nn.BatchNorm(num_features)
batch_norm.train()
running_mean = np.array(batch_norm._running_mean)
running_var = np.array(batch_norm._running_var)
running_mean = np.array(batch_norm.running_mean)
running_var = np.array(batch_norm.running_var)
data = mx.random.normal((batch_size, h, w, num_features))
normalized_data = batch_norm(data)
@@ -454,8 +474,8 @@ class TestNN(mlx_tests.MLXTestCase):
variances = np.var(np_data, axis=(0, 1, 2))
running_mean = (1 - momentum) * running_mean + momentum * means
running_var = (1 - momentum) * running_var + momentum * variances
self.assertTrue(np.allclose(batch_norm._running_mean, running_mean, atol=1e-5))
self.assertTrue(np.allclose(batch_norm._running_var, running_var, atol=1e-5))
self.assertTrue(np.allclose(batch_norm.running_mean, running_mean, atol=1e-5))
self.assertTrue(np.allclose(batch_norm.running_var, running_var, atol=1e-5))
def test_conv1d(self):
N = 5