mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
Fix BN stats to not expand shape (#409)
* fix BN stats to not expand shape * nit
This commit is contained in:
parent
753867123d
commit
e9ca65c939
@ -333,8 +333,8 @@ class BatchNorm(Module):
|
|||||||
"""
|
"""
|
||||||
reduction_axes = tuple(range(0, x.ndim - 1))
|
reduction_axes = tuple(range(0, x.ndim - 1))
|
||||||
|
|
||||||
mean = mx.mean(x, axis=reduction_axes, keepdims=True)
|
mean = mx.mean(x, axis=reduction_axes)
|
||||||
var = mx.var(x, axis=reduction_axes, keepdims=True)
|
var = mx.var(x, axis=reduction_axes)
|
||||||
|
|
||||||
return mean, var
|
return mean, var
|
||||||
|
|
||||||
|
@ -517,35 +517,36 @@ class TestLayers(mlx_tests.MLXTestCase):
|
|||||||
batch_norm = nn.BatchNorm(num_features)
|
batch_norm = nn.BatchNorm(num_features)
|
||||||
|
|
||||||
batch_norm.train()
|
batch_norm.train()
|
||||||
running_mean = np.array(batch_norm.running_mean)
|
running_mean = batch_norm.running_mean
|
||||||
running_var = np.array(batch_norm.running_var)
|
running_var = batch_norm.running_var
|
||||||
|
|
||||||
data = mx.random.normal((batch_size, num_features))
|
data = mx.random.normal((batch_size, num_features))
|
||||||
|
|
||||||
normalized_data = batch_norm(data)
|
normalized_data = batch_norm(data)
|
||||||
np_data = np.array(data)
|
means = mx.mean(data, axis=0)
|
||||||
means = np.mean(np_data, axis=0)
|
variances = mx.var(data, axis=0)
|
||||||
variances = np.var(np_data, axis=0)
|
|
||||||
running_mean = (1 - momentum) * running_mean + momentum * means
|
running_mean = (1 - momentum) * running_mean + momentum * means
|
||||||
running_var = (1 - momentum) * running_var + momentum * variances
|
running_var = (1 - momentum) * running_var + momentum * variances
|
||||||
self.assertTrue(np.allclose(batch_norm.running_mean, running_mean, atol=1e-5))
|
self.assertTrue(mx.allclose(batch_norm.running_mean, running_mean, atol=1e-5))
|
||||||
self.assertTrue(np.allclose(batch_norm.running_var, running_var, atol=1e-5))
|
self.assertTrue(mx.allclose(batch_norm.running_var, running_var, atol=1e-5))
|
||||||
|
|
||||||
batch_norm = nn.BatchNorm(num_features)
|
batch_norm = nn.BatchNorm(num_features)
|
||||||
|
|
||||||
batch_norm.train()
|
batch_norm.train()
|
||||||
running_mean = np.array(batch_norm.running_mean)
|
running_mean = batch_norm.running_mean
|
||||||
running_var = np.array(batch_norm.running_var)
|
running_var = batch_norm.running_var
|
||||||
data = mx.random.normal((batch_size, h, w, num_features))
|
data = mx.random.normal((batch_size, h, w, num_features))
|
||||||
|
|
||||||
normalized_data = batch_norm(data)
|
normalized_data = batch_norm(data)
|
||||||
np_data = np.array(data)
|
means = mx.mean(data, axis=(0, 1, 2))
|
||||||
means = np.mean(np_data, axis=(0, 1, 2))
|
variances = mx.var(data, axis=(0, 1, 2))
|
||||||
variances = np.var(np_data, axis=(0, 1, 2))
|
|
||||||
running_mean = (1 - momentum) * running_mean + momentum * means
|
running_mean = (1 - momentum) * running_mean + momentum * means
|
||||||
running_var = (1 - momentum) * running_var + momentum * variances
|
running_var = (1 - momentum) * running_var + momentum * variances
|
||||||
self.assertTrue(np.allclose(batch_norm.running_mean, running_mean, atol=1e-5))
|
self.assertTrue(mx.allclose(batch_norm.running_mean, running_mean, atol=1e-5))
|
||||||
self.assertTrue(np.allclose(batch_norm.running_var, running_var, atol=1e-5))
|
self.assertTrue(mx.allclose(batch_norm.running_var, running_var, atol=1e-5))
|
||||||
|
|
||||||
|
self.assertEqual(batch_norm.running_mean.shape, running_mean.shape)
|
||||||
|
self.assertEqual(batch_norm.running_var.shape, running_var.shape)
|
||||||
|
|
||||||
def test_conv1d(self):
|
def test_conv1d(self):
|
||||||
N = 5
|
N = 5
|
||||||
|
Loading…
Reference in New Issue
Block a user