Fix BN stats to not expand shape (#409)

* fix BN stats to not expand shape

* nit
This commit is contained in:
Awni Hannun 2024-01-09 11:54:51 -08:00 committed by GitHub
parent 753867123d
commit e9ca65c939
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 16 deletions

View File

@@ -333,8 +333,8 @@ class BatchNorm(Module):
         """
         reduction_axes = tuple(range(0, x.ndim - 1))
-        mean = mx.mean(x, axis=reduction_axes, keepdims=True)
-        var = mx.var(x, axis=reduction_axes, keepdims=True)
+        mean = mx.mean(x, axis=reduction_axes)
+        var = mx.var(x, axis=reduction_axes)
         return mean, var

View File

@@ -517,35 +517,36 @@ class TestLayers(mlx_tests.MLXTestCase):
         batch_norm = nn.BatchNorm(num_features)
         batch_norm.train()
-        running_mean = np.array(batch_norm.running_mean)
-        running_var = np.array(batch_norm.running_var)
+        running_mean = batch_norm.running_mean
+        running_var = batch_norm.running_var
         data = mx.random.normal((batch_size, num_features))
         normalized_data = batch_norm(data)
-        np_data = np.array(data)
-        means = np.mean(np_data, axis=0)
-        variances = np.var(np_data, axis=0)
+        means = mx.mean(data, axis=0)
+        variances = mx.var(data, axis=0)
         running_mean = (1 - momentum) * running_mean + momentum * means
         running_var = (1 - momentum) * running_var + momentum * variances
-        self.assertTrue(np.allclose(batch_norm.running_mean, running_mean, atol=1e-5))
-        self.assertTrue(np.allclose(batch_norm.running_var, running_var, atol=1e-5))
+        self.assertTrue(mx.allclose(batch_norm.running_mean, running_mean, atol=1e-5))
+        self.assertTrue(mx.allclose(batch_norm.running_var, running_var, atol=1e-5))

         batch_norm = nn.BatchNorm(num_features)
         batch_norm.train()
-        running_mean = np.array(batch_norm.running_mean)
-        running_var = np.array(batch_norm.running_var)
+        running_mean = batch_norm.running_mean
+        running_var = batch_norm.running_var
         data = mx.random.normal((batch_size, h, w, num_features))
         normalized_data = batch_norm(data)
-        np_data = np.array(data)
-        means = np.mean(np_data, axis=(0, 1, 2))
-        variances = np.var(np_data, axis=(0, 1, 2))
+        means = mx.mean(data, axis=(0, 1, 2))
+        variances = mx.var(data, axis=(0, 1, 2))
         running_mean = (1 - momentum) * running_mean + momentum * means
         running_var = (1 - momentum) * running_var + momentum * variances
-        self.assertTrue(np.allclose(batch_norm.running_mean, running_mean, atol=1e-5))
-        self.assertTrue(np.allclose(batch_norm.running_var, running_var, atol=1e-5))
+        self.assertTrue(mx.allclose(batch_norm.running_mean, running_mean, atol=1e-5))
+        self.assertTrue(mx.allclose(batch_norm.running_var, running_var, atol=1e-5))
+        self.assertEqual(batch_norm.running_mean.shape, running_mean.shape)
+        self.assertEqual(batch_norm.running_var.shape, running_var.shape)

     def test_conv1d(self):
         N = 5