refactored and updated batch norm tests ^^

This commit is contained in:
m0saan 2023-12-22 20:50:05 +01:00
parent 8b08f440d9
commit a43b853194
2 changed files with 64 additions and 10 deletions

View File

@ -270,7 +270,6 @@ class BatchNorm(Module):
self.dims_expanded = True self.dims_expanded = True
def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]: def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]:
""" """
Calculate the mean and variance of the input tensor. Calculate the mean and variance of the input tensor.

View File

@ -3,6 +3,7 @@
import os import os
import tempfile import tempfile
import unittest import unittest
from unittest.mock import Mock, patch
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
@ -410,6 +411,60 @@ class TestNN(mlx_tests.MLXTestCase):
self.assertTrue(np.allclose(bn.running_mean, expected_mean, atol=1e-5)) self.assertTrue(np.allclose(bn.running_mean, expected_mean, atol=1e-5))
self.assertTrue(np.allclose(bn.running_var, expected_var, atol=1e-5)) self.assertTrue(np.allclose(bn.running_var, expected_var, atol=1e-5))
def test_batch_norm_stats(self):
    """Check BatchNorm's running statistics against a NumPy reference.

    Two input layouts are exercised: a 2-D ``(batch, features)`` input
    and a 4-D ``(batch, channels, h, w)`` input. In each case the layer
    is called repeatedly in training mode, its output is fed back in as
    the next input, and after every call the layer's ``running_mean`` /
    ``running_var`` are compared with an exponential moving average of
    the per-batch mean and variance computed independently with NumPy.
    """
    batch_size = 4
    num_features = 32
    num_channels = 32
    h = 28
    w = 28
    num_iterations = 100
    # NOTE(review): the reference update below only matches if this equals
    # the momentum BatchNorm is constructed with (presumably its default)
    # -- confirm against the BatchNorm signature.
    momentum = 0.1

    # ---- 2-D input: statistics are reduced over the batch axis ----
    batch_norm = nn.BatchNorm(num_features)
    batch_norm.train()

    running_mean = np.array(batch_norm.running_mean.tolist())
    running_var = np.array(batch_norm.running_var.tolist())

    data = mx.random.normal((batch_size, num_features))

    for _ in range(num_iterations):
        normalized_data = batch_norm(data)
        # Convert once per iteration; going through ``tolist`` keeps the
        # reference computation in float64.
        np_data = np.array(data.tolist())
        means = np.mean(np_data, axis=0)
        variances = np.var(np_data, axis=0)
        running_mean = (1 - momentum) * running_mean + momentum * means
        running_var = (1 - momentum) * running_var + momentum * variances
        self.assertTrue(
            np.allclose(batch_norm.running_mean, running_mean, atol=1e-5)
        )
        self.assertTrue(
            np.allclose(batch_norm.running_var, running_var, atol=1e-5)
        )
        data = normalized_data

    # ---- 4-D input: statistics are reduced over axes (0, 2, 3) ----
    stats_shape = (1, num_channels, 1, 1)
    batch_norm = nn.BatchNorm(num_channels)
    batch_norm.train()

    running_mean = np.array(batch_norm.running_mean.tolist()).reshape(stats_shape)
    running_var = np.array(batch_norm.running_var.tolist()).reshape(stats_shape)

    data = mx.random.normal((batch_size, num_channels, h, w))

    for _ in range(num_iterations):
        normalized_data = batch_norm(data)
        np_data = np.array(data.tolist())
        means = np.mean(np_data, axis=(0, 2, 3)).reshape(stats_shape)
        variances = np.var(np_data, axis=(0, 2, 3)).reshape(stats_shape)
        running_mean = (1 - momentum) * running_mean + momentum * means
        running_var = (1 - momentum) * running_var + momentum * variances
        self.assertTrue(
            np.allclose(batch_norm.running_mean, running_mean, atol=1e-5)
        )
        self.assertTrue(
            np.allclose(batch_norm.running_var, running_var, atol=1e-5)
        )
        data = normalized_data
def test_conv1d(self): def test_conv1d(self):
N = 5 N = 5
L = 12 L = 12