refactored and updated batch norm tests ^^

m0saan 2023-12-22 20:50:05 +01:00
parent 8b08f440d9
commit a43b853194
2 changed files with 64 additions and 10 deletions


@@ -181,7 +181,7 @@ class GroupNorm(Module):
         x = group_norm(x)
         return (self.weight * x + self.bias) if "weight" in self else x
 
 class BatchNorm(Module):
     r"""Applies Batch Normalization over a 2D or 3D input.
@@ -211,7 +211,7 @@ class BatchNorm(Module):
         >>> bn = nn.BatchNorm(num_features=4, affine=True)
         >>> output = bn(x)
     """
 
     def __init__(
         self,
         num_features: int,
@@ -239,7 +239,7 @@ class BatchNorm(Module):
     def _extra_repr(self):
         return f"{self.num_features}, eps={self.eps}, momentum={self.momentum}, affine={'weight' in self}, track_running_stats={self.track_running_stats}"
 
     def _check_and_expand_dims(self, x: mx.array):
         """
         Check if the input is a 2D, 3D, or 4D tensor and expand the weight, bias, running mean, and running variance accordingly.
@@ -247,7 +247,7 @@ class BatchNorm(Module):
         Args:
             x (mx.array): Input tensor.
         """
 
         num_dims = len(x.shape)
         dims_dict = {
             2: ((1, self.num_features), (0,)),
@@ -259,17 +259,16 @@ class BatchNorm(Module):
             raise ValueError(f"expected num_dims to be 2, 3, or 4 (got {num_dims})")
 
         shape, self.reduction_axes = dims_dict[num_dims]
 
         if self.affine:
             self.weight = mx.expand_dims(self.weight, self.reduction_axes)
             self.bias = mx.expand_dims(self.bias, self.reduction_axes)
 
         if self.track_running_stats:
             self.running_mean = mx.expand_dims(self.running_mean, self.reduction_axes)
             self.running_var = mx.expand_dims(self.running_var, self.reduction_axes)
-            self.dims_expanded = True
+        self.dims_expanded = True
 
     def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]:
         """
@@ -304,7 +303,7 @@ class BatchNorm(Module):
         Returns:
             mx.array: Output tensor.
         """
 
         if not self.dims_expanded:
             self._check_and_expand_dims(x)
@@ -313,4 +312,4 @@ class BatchNorm(Module):
         else:
             means, var = self.running_mean, self.running_var
         x = (x - means) * mx.rsqrt(var + self.eps)
         return (self.weight * x + self.bias) if "weight" in self else x
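Read together, these hunks make the forward pass expand the parameters and running statistics once so they broadcast against the input (cached via dims_expanded), choose batch or running statistics depending on mode, normalize, and apply the affine transform. A minimal NumPy sketch of the eval-mode path (illustrative only; the function name, shapes, and the eps default are assumptions, not the MLX API):

import numpy as np

def batch_norm_eval(x, running_mean, running_var, weight, bias, eps=1e-5):
    # Normalize with the tracked statistics, mirroring
    # `x = (x - means) * mx.rsqrt(var + self.eps)` in the hunk above,
    # then apply the learned affine transform.
    x_hat = (x - running_mean) / np.sqrt(running_var + eps)
    return weight * x_hat + bias

x = np.random.randn(4, 32)       # the 2D case: (batch, num_features)
stats_shape = (1, 32)            # parameters expanded to (1, num_features)
out = batch_norm_eval(
    x,
    running_mean=np.zeros(stats_shape),
    running_var=np.ones(stats_shape),
    weight=np.ones(stats_shape),
    bias=np.zeros(stats_shape),
)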


@@ -3,6 +3,7 @@
 import os
 import tempfile
 import unittest
+from unittest.mock import Mock, patch
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -410,6 +411,60 @@ class TestNN(mlx_tests.MLXTestCase):
         self.assertTrue(np.allclose(bn.running_mean, expected_mean, atol=1e-5))
         self.assertTrue(np.allclose(bn.running_var, expected_var, atol=1e-5))
 
+    def test_batch_norm_stats(self):
+        batch_size = 4
+        num_features = 32
+        num_channels = 32
+        h = 28
+        w = 28
+        num_iterations = 100
+        momentum = 0.1
+
+        batch_norm = nn.BatchNorm(num_features)
+
+        batch_norm.train()
+        running_mean = np.array(batch_norm.running_mean.tolist())
+        running_var = np.array(batch_norm.running_var.tolist())
+
+        data = mx.random.normal((batch_size * num_features,)).reshape(
+            (batch_size, num_features)
+        )
+
+        for _ in range(num_iterations):
+            normalized_data = batch_norm(data)
+            means = np.mean(data.tolist(), axis=0)
+            variances = np.var(data.tolist(), axis=0)
+            running_mean = (1 - momentum) * running_mean + momentum * means
+            running_var = (1 - momentum) * running_var + momentum * variances
+            assert np.allclose(batch_norm.running_mean, running_mean, atol=1e-5)
+            assert np.allclose(batch_norm.running_var, running_var, atol=1e-5)
+            data = normalized_data
+
+        batch_norm = nn.BatchNorm(num_channels)
+
+        batch_norm.train()
+        running_mean = np.array(batch_norm.running_mean.tolist()).reshape(
+            1, num_channels, 1, 1
+        )
+        running_var = np.array(batch_norm.running_var.tolist()).reshape(
+            1, num_channels, 1, 1
+        )
+        data = mx.random.normal((batch_size, num_channels, h, w))
+
+        for _ in range(num_iterations):
+            normalized_data = batch_norm(data)
+            means = np.mean(data.tolist(), axis=(0, 2, 3)).reshape(
+                1, num_channels, 1, 1
+            )
+            variances = np.var(data.tolist(), axis=(0, 2, 3)).reshape(
+                1, num_channels, 1, 1
+            )
+            running_mean = (1 - momentum) * running_mean + momentum * means
+            running_var = (1 - momentum) * running_var + momentum * variances
+            assert np.allclose(batch_norm.running_mean, running_mean, atol=1e-5)
+            assert np.allclose(batch_norm.running_var, running_var, atol=1e-5)
+            data = normalized_data
+
     def test_conv1d(self):
         N = 5
         L = 12
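
For reference, the update rule the new test asserts is a plain exponential moving average of the batch statistics. A standalone sketch with the same default momentum of 0.1 (NumPy only; not the MLX implementation itself):

import numpy as np

momentum = 0.1
running_mean = np.zeros(32)   # BatchNorm starts the running mean at zeros
running_var = np.ones(32)     # ... and the running variance at ones

data = np.random.randn(4, 32)             # one (batch, num_features) batch
batch_mean = data.mean(axis=0)
batch_var = data.var(axis=0)

# The recurrence checked against batch_norm.running_mean / running_var:
# new_running = (1 - momentum) * old_running + momentum * batch_statistic
running_mean = (1 - momentum) * running_mean + momentum * batch_mean
running_var = (1 - momentum) * running_var + momentum * batch_var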