From 6b4f49fe1cd45f508cc9455c1f99ab3c00f67238 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Mon, 25 Dec 2023 07:27:14 -0800
Subject: [PATCH] cleanup stats test

---
 python/tests/test_nn.py | 70 +++++++++++++++++------------------------
 1 file changed, 28 insertions(+), 42 deletions(-)

diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index d2c83851e..2cfac4475 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -374,7 +374,7 @@ class TestNN(mlx_tests.MLXTestCase):
             ]
         )
         self.assertTrue(x.shape == y.shape)
-        self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
+        self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))
 
         # test with 3D input
         mx.random.seed(42)
@@ -418,58 +418,44 @@ class TestNN(mlx_tests.MLXTestCase):
             y = bn(x)
 
     def test_batch_norm_stats(self):
-        batch_size = 4
-        num_features = 32
-        num_channels = 32
-        h = 28
-        w = 28
-        num_iterations = 100
+        batch_size = 2
+        num_features = 4
+        h = 3
+        w = 3
         momentum = 0.1
 
         batch_norm = nn.BatchNorm(num_features)
 
         batch_norm.train()
-        running_mean = np.array(batch_norm._running_mean.tolist())
-        running_var = np.array(batch_norm._running_var.tolist())
+        running_mean = np.array(batch_norm._running_mean)
+        running_var = np.array(batch_norm._running_var)
 
-        data = mx.random.normal((batch_size * num_features,)).reshape(
-            (batch_size, num_features)
-        )
+        data = mx.random.normal((batch_size, num_features))
 
-        for _ in range(num_iterations):
-            normalized_data = batch_norm(data)
-            means = np.mean(data.tolist(), axis=0)
-            variances = np.var(data.tolist(), axis=0)
-            running_mean = (1 - momentum) * running_mean + momentum * means
-            running_var = (1 - momentum) * running_var + momentum * variances
-            assert np.allclose(batch_norm._running_mean, running_mean, atol=1e-5)
-            assert np.allclose(batch_norm._running_var, running_var, atol=1e-5)
-            data = normalized_data
+        normalized_data = batch_norm(data)
+        np_data = np.array(data)
+        means = np.mean(np_data, axis=0)
+        variances = np.var(np_data, axis=0)
+        running_mean = (1 - momentum) * running_mean + momentum * means
+        running_var = (1 - momentum) * running_var + momentum * variances
+        self.assertTrue(np.allclose(batch_norm._running_mean, running_mean, atol=1e-5))
+        self.assertTrue(np.allclose(batch_norm._running_var, running_var, atol=1e-5))
 
-        batch_norm = nn.BatchNorm(num_channels)
+        batch_norm = nn.BatchNorm(num_features)
 
         batch_norm.train()
-        running_mean = np.array(batch_norm._running_mean.tolist()).reshape(
-            1, 1, 1, num_channels
-        )
-        running_var = np.array(batch_norm._running_var.tolist()).reshape(
-            1, 1, 1, num_channels
-        )
-        data = mx.random.normal((batch_size, h, w, num_channels))
+        running_mean = np.array(batch_norm._running_mean)
+        running_var = np.array(batch_norm._running_var)
+        data = mx.random.normal((batch_size, h, w, num_features))
 
-        for _ in range(num_iterations):
-            normalized_data = batch_norm(data)
-            means = np.mean(data.tolist(), axis=(0, 1, 2)).reshape(
-                1, 1, 1, num_channels
-            )
-            variances = np.var(data.tolist(), axis=(0, 1, 2)).reshape(
-                1, 1, 1, num_channels
-            )
-            running_mean = (1 - momentum) * running_mean + momentum * means
-            running_var = (1 - momentum) * running_var + momentum * variances
-            assert np.allclose(batch_norm._running_mean, running_mean, atol=1e-5)
-            assert np.allclose(batch_norm._running_var, running_var, atol=1e-5)
-            data = normalized_data
+        normalized_data = batch_norm(data)
+        np_data = np.array(data)
+        means = np.mean(np_data, axis=(0, 1, 2))
+        variances = np.var(np_data, axis=(0, 1, 2))
+        running_mean = (1 - momentum) * running_mean + momentum * means
+        running_var = (1 - momentum) * running_var + momentum * variances
+        self.assertTrue(np.allclose(batch_norm._running_mean, running_mean, atol=1e-5))
+        self.assertTrue(np.allclose(batch_norm._running_var, running_var, atol=1e-5))
 
     def test_conv1d(self):
         N = 5