mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-20 18:39:45 +08:00
Fix BN stats to not expand shape (#409)
* fix BN stats to not expand shape * nit
This commit is contained in:
parent
753867123d
commit
e9ca65c939
@ -333,8 +333,8 @@ class BatchNorm(Module):
|
||||
"""
|
||||
reduction_axes = tuple(range(0, x.ndim - 1))
|
||||
|
||||
mean = mx.mean(x, axis=reduction_axes, keepdims=True)
|
||||
var = mx.var(x, axis=reduction_axes, keepdims=True)
|
||||
mean = mx.mean(x, axis=reduction_axes)
|
||||
var = mx.var(x, axis=reduction_axes)
|
||||
|
||||
return mean, var
|
||||
|
||||
|
@ -517,35 +517,36 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
batch_norm = nn.BatchNorm(num_features)
|
||||
|
||||
batch_norm.train()
|
||||
running_mean = np.array(batch_norm.running_mean)
|
||||
running_var = np.array(batch_norm.running_var)
|
||||
running_mean = batch_norm.running_mean
|
||||
running_var = batch_norm.running_var
|
||||
|
||||
data = mx.random.normal((batch_size, num_features))
|
||||
|
||||
normalized_data = batch_norm(data)
|
||||
np_data = np.array(data)
|
||||
means = np.mean(np_data, axis=0)
|
||||
variances = np.var(np_data, axis=0)
|
||||
means = mx.mean(data, axis=0)
|
||||
variances = mx.var(data, axis=0)
|
||||
running_mean = (1 - momentum) * running_mean + momentum * means
|
||||
running_var = (1 - momentum) * running_var + momentum * variances
|
||||
self.assertTrue(np.allclose(batch_norm.running_mean, running_mean, atol=1e-5))
|
||||
self.assertTrue(np.allclose(batch_norm.running_var, running_var, atol=1e-5))
|
||||
self.assertTrue(mx.allclose(batch_norm.running_mean, running_mean, atol=1e-5))
|
||||
self.assertTrue(mx.allclose(batch_norm.running_var, running_var, atol=1e-5))
|
||||
|
||||
batch_norm = nn.BatchNorm(num_features)
|
||||
|
||||
batch_norm.train()
|
||||
running_mean = np.array(batch_norm.running_mean)
|
||||
running_var = np.array(batch_norm.running_var)
|
||||
running_mean = batch_norm.running_mean
|
||||
running_var = batch_norm.running_var
|
||||
data = mx.random.normal((batch_size, h, w, num_features))
|
||||
|
||||
normalized_data = batch_norm(data)
|
||||
np_data = np.array(data)
|
||||
means = np.mean(np_data, axis=(0, 1, 2))
|
||||
variances = np.var(np_data, axis=(0, 1, 2))
|
||||
means = mx.mean(data, axis=(0, 1, 2))
|
||||
variances = mx.var(data, axis=(0, 1, 2))
|
||||
running_mean = (1 - momentum) * running_mean + momentum * means
|
||||
running_var = (1 - momentum) * running_var + momentum * variances
|
||||
self.assertTrue(np.allclose(batch_norm.running_mean, running_mean, atol=1e-5))
|
||||
self.assertTrue(np.allclose(batch_norm.running_var, running_var, atol=1e-5))
|
||||
self.assertTrue(mx.allclose(batch_norm.running_mean, running_mean, atol=1e-5))
|
||||
self.assertTrue(mx.allclose(batch_norm.running_var, running_var, atol=1e-5))
|
||||
|
||||
self.assertEqual(batch_norm.running_mean.shape, running_mean.shape)
|
||||
self.assertEqual(batch_norm.running_var.shape, running_var.shape)
|
||||
|
||||
def test_conv1d(self):
|
||||
N = 5
|
||||
|
Loading…
Reference in New Issue
Block a user