mirror of https://github.com/ml-explore/mlx.git
cleanup stats test
parent 865e53fcab
commit 6b4f49fe1c
@@ -374,7 +374,7 @@ class TestNN(mlx_tests.MLXTestCase):
             ]
         )
         self.assertTrue(x.shape == y.shape)
-        self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
+        self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))

         # test with 3D input
         mx.random.seed(42)
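A note on the assertion change above: mx.allclose compares MLX arrays directly, so the expected value no longer needs a detour through NumPy, and its scalar result is truthy when all elements match within tolerance, which is why it drops straight into assertTrue. A minimal standalone sketch of the call, not part of the commit:

import mlx.core as mx

x = mx.array([1.0, 2.0, 3.0])
y = x + 1e-7  # tiny numerical perturbation

# Compare MLX arrays directly; the scalar result is truthy when all
# elements are equal within the given tolerance.
assert mx.allclose(x, y, atol=1e-5)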
@@ -418,58 +418,44 @@ class TestNN(mlx_tests.MLXTestCase):
         y = bn(x)

     def test_batch_norm_stats(self):
-        batch_size = 4
-        num_features = 32
-        num_channels = 32
-        h = 28
-        w = 28
-        num_iterations = 100
+        batch_size = 2
+        num_features = 4
+        h = 3
+        w = 3
         momentum = 0.1

         batch_norm = nn.BatchNorm(num_features)

         batch_norm.train()
-        running_mean = np.array(batch_norm._running_mean.tolist())
-        running_var = np.array(batch_norm._running_var.tolist())
+        running_mean = np.array(batch_norm._running_mean)
+        running_var = np.array(batch_norm._running_var)

-        data = mx.random.normal((batch_size * num_features,)).reshape(
-            (batch_size, num_features)
-        )
+        data = mx.random.normal((batch_size, num_features))

-        for _ in range(num_iterations):
-            normalized_data = batch_norm(data)
-            means = np.mean(data.tolist(), axis=0)
-            variances = np.var(data.tolist(), axis=0)
-            running_mean = (1 - momentum) * running_mean + momentum * means
-            running_var = (1 - momentum) * running_var + momentum * variances
-            assert np.allclose(batch_norm._running_mean, running_mean, atol=1e-5)
-            assert np.allclose(batch_norm._running_var, running_var, atol=1e-5)
-            data = normalized_data
+        normalized_data = batch_norm(data)
+        np_data = np.array(data)
+        means = np.mean(np_data, axis=0)
+        variances = np.var(np_data, axis=0)
+        running_mean = (1 - momentum) * running_mean + momentum * means
+        running_var = (1 - momentum) * running_var + momentum * variances
+        self.assertTrue(np.allclose(batch_norm._running_mean, running_mean, atol=1e-5))
+        self.assertTrue(np.allclose(batch_norm._running_var, running_var, atol=1e-5))

-        batch_norm = nn.BatchNorm(num_channels)
+        batch_norm = nn.BatchNorm(num_features)

         batch_norm.train()
-        running_mean = np.array(batch_norm._running_mean.tolist()).reshape(
-            1, 1, 1, num_channels
-        )
-        running_var = np.array(batch_norm._running_var.tolist()).reshape(
-            1, 1, 1, num_channels
-        )
-        data = mx.random.normal((batch_size, h, w, num_channels))
+        running_mean = np.array(batch_norm._running_mean)
+        running_var = np.array(batch_norm._running_var)
+        data = mx.random.normal((batch_size, h, w, num_features))

-        for _ in range(num_iterations):
-            normalized_data = batch_norm(data)
-            means = np.mean(data.tolist(), axis=(0, 1, 2)).reshape(
-                1, 1, 1, num_channels
-            )
-            variances = np.var(data.tolist(), axis=(0, 1, 2)).reshape(
-                1, 1, 1, num_channels
-            )
-            running_mean = (1 - momentum) * running_mean + momentum * means
-            running_var = (1 - momentum) * running_var + momentum * variances
-            assert np.allclose(batch_norm._running_mean, running_mean, atol=1e-5)
-            assert np.allclose(batch_norm._running_var, running_var, atol=1e-5)
-            data = normalized_data
+        normalized_data = batch_norm(data)
+        np_data = np.array(data)
+        means = np.mean(np_data, axis=(0, 1, 2))
+        variances = np.var(np_data, axis=(0, 1, 2))
+        running_mean = (1 - momentum) * running_mean + momentum * means
+        running_var = (1 - momentum) * running_var + momentum * variances
+        self.assertTrue(np.allclose(batch_norm._running_mean, running_mean, atol=1e-5))
+        self.assertTrue(np.allclose(batch_norm._running_var, running_var, atol=1e-5))

     def test_conv1d(self):
         N = 5
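The invariant the cleaned-up test verifies is BatchNorm's exponential-moving-average update of its running statistics after a single training-mode forward pass. A minimal NumPy sketch of that update, assuming the momentum of 0.1 used in the test and the usual initialization of the running mean to zeros and the running variance to ones:

import numpy as np

momentum = 0.1
num_features = 4

# Assumed initial running statistics (zeros / ones), mirroring what the
# test reads from batch_norm._running_mean and batch_norm._running_var.
running_mean = np.zeros(num_features)
running_var = np.ones(num_features)

# One batch of data with shape (batch_size, num_features).
batch = np.random.randn(2, num_features)

# Per-feature statistics over the batch axis (the test's 4D case reduces
# over axes (0, 1, 2) instead).
batch_mean = batch.mean(axis=0)
batch_var = batch.var(axis=0)

# Exponential moving average, matching the update the test checks:
#   new_running = (1 - momentum) * old_running + momentum * batch_stat
running_mean = (1 - momentum) * running_mean + momentum * batch_mean
running_var = (1 - momentum) * running_var + momentum * batch_var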