cleanup stats test

This commit is contained in:
Awni Hannun 2023-12-25 07:27:14 -08:00
parent 865e53fcab
commit 6b4f49fe1c

View File

@ -374,7 +374,7 @@ class TestNN(mlx_tests.MLXTestCase):
] ]
) )
self.assertTrue(x.shape == y.shape) self.assertTrue(x.shape == y.shape)
self.assertTrue(np.allclose(y, expected_y, atol=1e-5)) self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))
# test with 3D input # test with 3D input
mx.random.seed(42) mx.random.seed(42)
@ -418,58 +418,44 @@ class TestNN(mlx_tests.MLXTestCase):
y = bn(x) y = bn(x)
def test_batch_norm_stats(self): def test_batch_norm_stats(self):
batch_size = 4 batch_size = 2
num_features = 32 num_features = 4
num_channels = 32 h = 3
h = 28 w = 3
w = 28
num_iterations = 100
momentum = 0.1 momentum = 0.1
batch_norm = nn.BatchNorm(num_features) batch_norm = nn.BatchNorm(num_features)
batch_norm.train() batch_norm.train()
running_mean = np.array(batch_norm._running_mean.tolist()) running_mean = np.array(batch_norm._running_mean)
running_var = np.array(batch_norm._running_var.tolist()) running_var = np.array(batch_norm._running_var)
data = mx.random.normal((batch_size * num_features,)).reshape( data = mx.random.normal((batch_size, num_features))
(batch_size, num_features)
)
for _ in range(num_iterations): normalized_data = batch_norm(data)
normalized_data = batch_norm(data) np_data = np.array(data)
means = np.mean(data.tolist(), axis=0) means = np.mean(np_data, axis=0)
variances = np.var(data.tolist(), axis=0) variances = np.var(np_data, axis=0)
running_mean = (1 - momentum) * running_mean + momentum * means running_mean = (1 - momentum) * running_mean + momentum * means
running_var = (1 - momentum) * running_var + momentum * variances running_var = (1 - momentum) * running_var + momentum * variances
assert np.allclose(batch_norm._running_mean, running_mean, atol=1e-5) self.assertTrue(np.allclose(batch_norm._running_mean, running_mean, atol=1e-5))
assert np.allclose(batch_norm._running_var, running_var, atol=1e-5) self.assertTrue(np.allclose(batch_norm._running_var, running_var, atol=1e-5))
data = normalized_data
batch_norm = nn.BatchNorm(num_channels) batch_norm = nn.BatchNorm(num_features)
batch_norm.train() batch_norm.train()
running_mean = np.array(batch_norm._running_mean.tolist()).reshape( running_mean = np.array(batch_norm._running_mean)
1, 1, 1, num_channels running_var = np.array(batch_norm._running_var)
) data = mx.random.normal((batch_size, h, w, num_features))
running_var = np.array(batch_norm._running_var.tolist()).reshape(
1, 1, 1, num_channels
)
data = mx.random.normal((batch_size, h, w, num_channels))
for _ in range(num_iterations): normalized_data = batch_norm(data)
normalized_data = batch_norm(data) np_data = np.array(data)
means = np.mean(data.tolist(), axis=(0, 1, 2)).reshape( means = np.mean(np_data, axis=(0, 1, 2))
1, 1, 1, num_channels variances = np.var(np_data, axis=(0, 1, 2))
) running_mean = (1 - momentum) * running_mean + momentum * means
variances = np.var(data.tolist(), axis=(0, 1, 2)).reshape( running_var = (1 - momentum) * running_var + momentum * variances
1, 1, 1, num_channels self.assertTrue(np.allclose(batch_norm._running_mean, running_mean, atol=1e-5))
) self.assertTrue(np.allclose(batch_norm._running_var, running_var, atol=1e-5))
running_mean = (1 - momentum) * running_mean + momentum * means
running_var = (1 - momentum) * running_var + momentum * variances
assert np.allclose(batch_norm._running_mean, running_mean, atol=1e-5)
assert np.allclose(batch_norm._running_var, running_var, atol=1e-5)
data = normalized_data
def test_conv1d(self): def test_conv1d(self):
N = 5 N = 5