Update batchnorm to have the running stats in parameters (#305)

This commit is contained in:
Angelos Katharopoulos 2023-12-28 14:31:10 -08:00 committed by GitHub
parent 040c3bafab
commit d29770eeaa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 65 additions and 37 deletions

View File

@ -243,8 +243,15 @@ class BatchNorm(Module):
self.bias = mx.zeros((num_features,)) self.bias = mx.zeros((num_features,))
if self.track_running_stats: if self.track_running_stats:
self._running_mean = mx.zeros((num_features,)) self.running_mean = mx.zeros((num_features,))
self._running_var = mx.ones((num_features,)) self.running_var = mx.ones((num_features,))
self.freeze(keys=["running_mean", "running_var"], recurse=False)
def unfreeze(self, *args, **kwargs):
"""Wrap unfreeze to make sure that running_mean and var are always
frozen parameters."""
super().unfreeze(*args, **kwargs)
self.freeze(keys=["running_mean", "running_var"], recurse=False)
def _extra_repr(self): def _extra_repr(self):
return ( return (
@ -255,46 +262,47 @@ class BatchNorm(Module):
def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]: def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]:
""" """
Calculate the mean and variance of the input tensor. Calculate the mean and variance of the input tensor across the batch
and spatial dimensions.
Args: Args:
x (mx.array): Input tensor. x (array): Input tensor.
Returns: Returns:
tuple: Tuple containing mean and variance. tuple: Tuple containing mean and variance.
""" """
reduction_axes = tuple(range(0, x.ndim - 1)) reduction_axes = tuple(range(0, x.ndim - 1))
means = mx.mean(x, axis=reduction_axes, keepdims=True)
mean = mx.mean(x, axis=reduction_axes, keepdims=True)
var = mx.var(x, axis=reduction_axes, keepdims=True) var = mx.var(x, axis=reduction_axes, keepdims=True)
if self.track_running_stats and self.training: return mean, var
self._running_mean = (
1 - self.momentum
) * self._running_mean + self.momentum * means
self._running_var = (
1 - self.momentum
) * self._running_var + self.momentum * var
return means, var
def __call__(self, x: mx.array) -> mx.array: def __call__(self, x: mx.array) -> mx.array:
""" """
Forward pass of BatchNorm. Forward pass of BatchNorm.
Args: Args:
x (mx.array): Input tensor. x (array): Input tensor.
Returns: Returns:
mx.array: Output tensor. array: Normalized output tensor.
""" """
if x.ndim < 2 or x.ndim > 4: if x.ndim < 2 or x.ndim > 4:
raise ValueError( raise ValueError(
f"Expected input tensor to have 2, 3 or 4 dimensions, but got {x.ndim}" f"Expected input tensor to have 2, 3 or 4 dimensions, but got {x.ndim}"
) )
if self.training or not self.track_running_stats: # Calculate the mean and variance used to normalize the input x. If we
means, var = self._calc_stats(x) # are in training mode update the running stats if needed.
else: mean, var = self._calc_stats(x)
means, var = self._running_mean, self._running_var if self.training and self.track_running_stats:
x = (x - means) * mx.rsqrt(var + self.eps) mu = self.momentum
self.running_mean = (1 - mu) * self.running_mean + mu * mean
self.running_var = (1 - mu) * self.running_var + mu * var
elif self.track_running_stats:
mean = self.running_mean
var = self.running_var
x = (x - mean) * mx.rsqrt(var + self.eps)
return (self.weight * x + self.bias) if "weight" in self else x return (self.weight * x + self.bias) if "weight" in self else x

View File

@ -326,8 +326,8 @@ class TestNN(mlx_tests.MLXTestCase):
# Batch norm # Batch norm
bn = nn.BatchNorm(num_features=4, affine=True) bn = nn.BatchNorm(num_features=4, affine=True)
self.assertTrue(mx.allclose(bn._running_mean, mx.zeros_like(bn._running_mean))) self.assertTrue(mx.allclose(bn.running_mean, mx.zeros_like(bn.running_mean)))
self.assertTrue(mx.allclose(bn._running_var, mx.ones_like(bn._running_var))) self.assertTrue(mx.allclose(bn.running_var, mx.ones_like(bn.running_var)))
y = bn(x) y = bn(x)
expected_y = mx.array( expected_y = mx.array(
[ [
@ -342,8 +342,8 @@ class TestNN(mlx_tests.MLXTestCase):
expected_var = mx.array([0.928435, 1.00455, 1.04117, 0.94258]) expected_var = mx.array([0.928435, 1.00455, 1.04117, 0.94258])
self.assertTrue(x.shape == y.shape) self.assertTrue(x.shape == y.shape)
self.assertTrue(mx.allclose(y, expected_y, atol=1e-5)) self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))
self.assertTrue(mx.allclose(bn._running_mean, expected_mean, atol=1e-5)) self.assertTrue(mx.allclose(bn.running_mean, expected_mean, atol=1e-5))
self.assertTrue(mx.allclose(bn._running_var, expected_var, atol=1e-5)) self.assertTrue(mx.allclose(bn.running_var, expected_var, atol=1e-5))
# test eval mode # test eval mode
bn.eval() bn.eval()
@ -385,8 +385,8 @@ class TestNN(mlx_tests.MLXTestCase):
# Batch norm # Batch norm
bn = nn.BatchNorm(num_features=C, affine=True) bn = nn.BatchNorm(num_features=C, affine=True)
self.assertTrue(mx.allclose(bn._running_mean, mx.zeros_like(bn._running_mean))) self.assertTrue(mx.allclose(bn.running_mean, mx.zeros_like(bn.running_mean)))
self.assertTrue(mx.allclose(bn._running_var, mx.ones_like(bn._running_var))) self.assertTrue(mx.allclose(bn.running_var, mx.ones_like(bn.running_var)))
y = bn(x) y = bn(x)
self.assertTrue(x.shape == y.shape) self.assertTrue(x.shape == y.shape)
expected_y = mx.array( expected_y = mx.array(
@ -410,13 +410,33 @@ class TestNN(mlx_tests.MLXTestCase):
[[[0.00207845, -5.3259e-05, 0.04755, -0.0697296, 0.0236228]]] [[[0.00207845, -5.3259e-05, 0.04755, -0.0697296, 0.0236228]]]
) )
expected_var = mx.array([[[0.968415, 1.05322, 0.96913, 0.932305, 0.967224]]]) expected_var = mx.array([[[0.968415, 1.05322, 0.96913, 0.932305, 0.967224]]])
self.assertTrue(mx.allclose(bn._running_mean, expected_mean, atol=1e-5)) self.assertTrue(mx.allclose(bn.running_mean, expected_mean, atol=1e-5))
self.assertTrue(mx.allclose(bn._running_var, expected_var, atol=1e-5)) self.assertTrue(mx.allclose(bn.running_var, expected_var, atol=1e-5))
x = mx.random.normal((N, L, C, L, C), dtype=mx.float32) x = mx.random.normal((N, L, C, L, C), dtype=mx.float32)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
y = bn(x) y = bn(x)
# Check that the running stats are in the param dictionary
bn_parameters = bn.parameters()
self.assertIn("running_mean", bn_parameters)
self.assertIn("running_var", bn_parameters)
self.assertIn("weight", bn_parameters)
self.assertIn("bias", bn_parameters)
bn_trainable = bn.trainable_parameters()
self.assertNotIn("running_mean", bn_trainable)
self.assertNotIn("running_var", bn_trainable)
self.assertIn("weight", bn_trainable)
self.assertIn("bias", bn_trainable)
bn.unfreeze()
bn_trainable = bn.trainable_parameters()
self.assertNotIn("running_mean", bn_trainable)
self.assertNotIn("running_var", bn_trainable)
self.assertIn("weight", bn_trainable)
self.assertIn("bias", bn_trainable)
def test_batch_norm_stats(self): def test_batch_norm_stats(self):
batch_size = 2 batch_size = 2
num_features = 4 num_features = 4
@ -427,8 +447,8 @@ class TestNN(mlx_tests.MLXTestCase):
batch_norm = nn.BatchNorm(num_features) batch_norm = nn.BatchNorm(num_features)
batch_norm.train() batch_norm.train()
running_mean = np.array(batch_norm._running_mean) running_mean = np.array(batch_norm.running_mean)
running_var = np.array(batch_norm._running_var) running_var = np.array(batch_norm.running_var)
data = mx.random.normal((batch_size, num_features)) data = mx.random.normal((batch_size, num_features))
@ -438,14 +458,14 @@ class TestNN(mlx_tests.MLXTestCase):
variances = np.var(np_data, axis=0) variances = np.var(np_data, axis=0)
running_mean = (1 - momentum) * running_mean + momentum * means running_mean = (1 - momentum) * running_mean + momentum * means
running_var = (1 - momentum) * running_var + momentum * variances running_var = (1 - momentum) * running_var + momentum * variances
self.assertTrue(np.allclose(batch_norm._running_mean, running_mean, atol=1e-5)) self.assertTrue(np.allclose(batch_norm.running_mean, running_mean, atol=1e-5))
self.assertTrue(np.allclose(batch_norm._running_var, running_var, atol=1e-5)) self.assertTrue(np.allclose(batch_norm.running_var, running_var, atol=1e-5))
batch_norm = nn.BatchNorm(num_features) batch_norm = nn.BatchNorm(num_features)
batch_norm.train() batch_norm.train()
running_mean = np.array(batch_norm._running_mean) running_mean = np.array(batch_norm.running_mean)
running_var = np.array(batch_norm._running_var) running_var = np.array(batch_norm.running_var)
data = mx.random.normal((batch_size, h, w, num_features)) data = mx.random.normal((batch_size, h, w, num_features))
normalized_data = batch_norm(data) normalized_data = batch_norm(data)
@ -454,8 +474,8 @@ class TestNN(mlx_tests.MLXTestCase):
variances = np.var(np_data, axis=(0, 1, 2)) variances = np.var(np_data, axis=(0, 1, 2))
running_mean = (1 - momentum) * running_mean + momentum * means running_mean = (1 - momentum) * running_mean + momentum * means
running_var = (1 - momentum) * running_var + momentum * variances running_var = (1 - momentum) * running_var + momentum * variances
self.assertTrue(np.allclose(batch_norm._running_mean, running_mean, atol=1e-5)) self.assertTrue(np.allclose(batch_norm.running_mean, running_mean, atol=1e-5))
self.assertTrue(np.allclose(batch_norm._running_var, running_var, atol=1e-5)) self.assertTrue(np.allclose(batch_norm.running_var, running_var, atol=1e-5))
def test_conv1d(self): def test_conv1d(self):
N = 5 N = 5