Fix BN stats to not expand shape (#409)

* fix BN stats to not expand shape

* nit
This commit is contained in:
Awni Hannun 2024-01-09 11:54:51 -08:00 committed by GitHub
parent 753867123d
commit e9ca65c939
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 16 deletions

View File

@ -333,8 +333,8 @@ class BatchNorm(Module):
"""
reduction_axes = tuple(range(0, x.ndim - 1))
mean = mx.mean(x, axis=reduction_axes, keepdims=True)
var = mx.var(x, axis=reduction_axes, keepdims=True)
mean = mx.mean(x, axis=reduction_axes)
var = mx.var(x, axis=reduction_axes)
return mean, var

View File

@ -517,35 +517,36 @@ class TestLayers(mlx_tests.MLXTestCase):
batch_norm = nn.BatchNorm(num_features)
batch_norm.train()
running_mean = np.array(batch_norm.running_mean)
running_var = np.array(batch_norm.running_var)
running_mean = batch_norm.running_mean
running_var = batch_norm.running_var
data = mx.random.normal((batch_size, num_features))
normalized_data = batch_norm(data)
np_data = np.array(data)
means = np.mean(np_data, axis=0)
variances = np.var(np_data, axis=0)
means = mx.mean(data, axis=0)
variances = mx.var(data, axis=0)
running_mean = (1 - momentum) * running_mean + momentum * means
running_var = (1 - momentum) * running_var + momentum * variances
self.assertTrue(np.allclose(batch_norm.running_mean, running_mean, atol=1e-5))
self.assertTrue(np.allclose(batch_norm.running_var, running_var, atol=1e-5))
self.assertTrue(mx.allclose(batch_norm.running_mean, running_mean, atol=1e-5))
self.assertTrue(mx.allclose(batch_norm.running_var, running_var, atol=1e-5))
batch_norm = nn.BatchNorm(num_features)
batch_norm.train()
running_mean = np.array(batch_norm.running_mean)
running_var = np.array(batch_norm.running_var)
running_mean = batch_norm.running_mean
running_var = batch_norm.running_var
data = mx.random.normal((batch_size, h, w, num_features))
normalized_data = batch_norm(data)
np_data = np.array(data)
means = np.mean(np_data, axis=(0, 1, 2))
variances = np.var(np_data, axis=(0, 1, 2))
means = mx.mean(data, axis=(0, 1, 2))
variances = mx.var(data, axis=(0, 1, 2))
running_mean = (1 - momentum) * running_mean + momentum * means
running_var = (1 - momentum) * running_var + momentum * variances
self.assertTrue(np.allclose(batch_norm.running_mean, running_mean, atol=1e-5))
self.assertTrue(np.allclose(batch_norm.running_var, running_var, atol=1e-5))
self.assertTrue(mx.allclose(batch_norm.running_mean, running_mean, atol=1e-5))
self.assertTrue(mx.allclose(batch_norm.running_var, running_var, atol=1e-5))
self.assertEqual(batch_norm.running_mean.shape, running_mean.shape)
self.assertEqual(batch_norm.running_var.shape, running_var.shape)
def test_conv1d(self):
N = 5