mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
refactored and updated batch norm tests ^^
This commit is contained in:
parent
8b08f440d9
commit
a43b853194
@ -270,7 +270,6 @@ class BatchNorm(Module):
|
||||
|
||||
self.dims_expanded = True
|
||||
|
||||
|
||||
def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]:
|
||||
"""
|
||||
Calculate the mean and variance of the input tensor.
|
||||
|
@ -3,6 +3,7 @@
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@ -410,6 +411,60 @@ class TestNN(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(np.allclose(bn.running_mean, expected_mean, atol=1e-5))
|
||||
self.assertTrue(np.allclose(bn.running_var, expected_var, atol=1e-5))
|
||||
|
||||
def test_batch_norm_stats(self):
|
||||
batch_size = 4
|
||||
num_features = 32
|
||||
num_channels = 32
|
||||
h = 28
|
||||
w = 28
|
||||
num_iterations = 100
|
||||
momentum = 0.1
|
||||
|
||||
batch_norm = nn.BatchNorm(num_features)
|
||||
|
||||
batch_norm.train()
|
||||
running_mean = np.array(batch_norm.running_mean.tolist())
|
||||
running_var = np.array(batch_norm.running_var.tolist())
|
||||
|
||||
data = mx.random.normal((batch_size * num_features,)).reshape(
|
||||
(batch_size, num_features)
|
||||
)
|
||||
|
||||
for _ in range(num_iterations):
|
||||
normalized_data = batch_norm(data)
|
||||
means = np.mean(data.tolist(), axis=0)
|
||||
variances = np.var(data.tolist(), axis=0)
|
||||
running_mean = (1 - momentum) * running_mean + momentum * means
|
||||
running_var = (1 - momentum) * running_var + momentum * variances
|
||||
assert np.allclose(batch_norm.running_mean, running_mean, atol=1e-5)
|
||||
assert np.allclose(batch_norm.running_var, running_var, atol=1e-5)
|
||||
data = normalized_data
|
||||
|
||||
batch_norm = nn.BatchNorm(num_channels)
|
||||
|
||||
batch_norm.train()
|
||||
running_mean = np.array(batch_norm.running_mean.tolist()).reshape(
|
||||
1, num_channels, 1, 1
|
||||
)
|
||||
running_var = np.array(batch_norm.running_var.tolist()).reshape(
|
||||
1, num_channels, 1, 1
|
||||
)
|
||||
data = mx.random.normal((batch_size, num_channels, h, w))
|
||||
|
||||
for _ in range(num_iterations):
|
||||
normalized_data = batch_norm(data)
|
||||
means = np.mean(data.tolist(), axis=(0, 2, 3)).reshape(
|
||||
1, num_channels, 1, 1
|
||||
)
|
||||
variances = np.var(data.tolist(), axis=(0, 2, 3)).reshape(
|
||||
1, num_channels, 1, 1
|
||||
)
|
||||
running_mean = (1 - momentum) * running_mean + momentum * means
|
||||
running_var = (1 - momentum) * running_var + momentum * variances
|
||||
assert np.allclose(batch_norm.running_mean, running_mean, atol=1e-5)
|
||||
assert np.allclose(batch_norm.running_var, running_var, atol=1e-5)
|
||||
data = normalized_data
|
||||
|
||||
def test_conv1d(self):
|
||||
N = 5
|
||||
L = 12
|
||||
|
Loading…
Reference in New Issue
Block a user