mirror of https://github.com/ml-explore/mlx.git (synced 2025-08-21 20:46:46 +08:00)

commit a43b853194
parent 8b08f440d9

    refactored and updated batch norm tests ^^
@@ -181,7 +181,7 @@ class GroupNorm(Module):
         x = group_norm(x)
         return (self.weight * x + self.bias) if "weight" in self else x


 class BatchNorm(Module):
     r"""Applies Batch Normalization over a 2D or 3D input.

@@ -211,7 +211,7 @@ class BatchNorm(Module):
         >>> bn = nn.BatchNorm1d(num_features=4, affine=True)
         >>> output = bn(x)
     """

     def __init__(
         self,
         num_features: int,
@@ -239,7 +239,7 @@ class BatchNorm(Module):

     def _extra_repr(self):
         return f"{self.num_features}, eps={self.eps}, momentum={self.momentum}, affine={'weight' in self}, track_running_stats={self.track_running_stats}"

     def _check_and_expand_dims(self, x: mx.array):
         """
         Check if the input is a 2D or 3D tensor and expand the weight, bias, running mean, and running variance accordingly.
@@ -247,7 +247,7 @@ class BatchNorm(Module):
         Args:
             x (mx.array): Input tensor.
         """

         num_dims = len(x.shape)
         dims_dict = {
             2: ((1, self.num_features), (0,)),
@@ -259,17 +259,16 @@ class BatchNorm(Module):
             raise ValueError(f"expected num_dims to be 2, 3, or 4 (got {num_dims})")

         shape, self.reduction_axes = dims_dict[num_dims]

         if self.affine:
             self.weight = mx.expand_dims(self.weight, self.reduction_axes)
             self.bias = mx.expand_dims(self.bias, self.reduction_axes)

         if self.track_running_stats:
             self.running_mean = mx.expand_dims(self.running_mean, self.reduction_axes)
             self.running_var = mx.expand_dims(self.running_var, self.reduction_axes)

-            self.dims_expanded = True
-
+        self.dims_expanded = True

     def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]:
         """
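
A note on the hunk above: expanding the per-feature parameters along the reduction axes is what lets them broadcast against the input. A minimal NumPy sketch of the idea, assuming the 4D entry of dims_dict reduces over axes (0, 2, 3) as the new test below does (only the 2D entry is visible in these hunks):

import numpy as np

num_features = 32
weight = np.ones(num_features)      # per-feature scale, shape (32,)
reduction_axes = (0, 2, 3)          # assumed axes for a 4D (N, C, H, W) input
# np.expand_dims accepts a tuple of axes, mirroring mx.expand_dims above
weight = np.expand_dims(weight, reduction_axes)
print(weight.shape)                 # (1, 32, 1, 1), broadcastable over (N, C, H, W)
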
@@ -304,7 +303,7 @@ class BatchNorm(Module):
         Returns:
             mx.array: Output tensor.
         """

         if not self.dims_expanded:
             self._check_and_expand_dims(x)

@@ -313,4 +312,4 @@ class BatchNorm(Module):
         else:
             means, var = self.running_mean, self.running_var
         x = (x - means) * mx.rsqrt(var + self.eps)
         return (self.weight * x + self.bias) if "weight" in self else x
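
The arithmetic in __call__ is straightforward to cross-check. A minimal NumPy sketch of the same computation for a 2D (N, C) input in training mode (ref_batch_norm is an illustrative name, not part of mlx):

import numpy as np

def ref_batch_norm(x, weight, bias, eps=1e-5):
    # batch statistics over axis 0: one mean and variance per feature
    mean = x.mean(axis=0, keepdims=True)
    var = x.var(axis=0, keepdims=True)
    # (x - mean) / sqrt(var + eps) equals (x - mean) * rsqrt(var + eps)
    x_hat = (x - mean) / np.sqrt(var + eps)
    return weight * x_hat + bias

x = np.random.randn(4, 32).astype(np.float32)
out = ref_batch_norm(x, np.ones((1, 32)), np.zeros((1, 32)))
assert np.allclose(out.mean(axis=0), 0.0, atol=1e-5)  # normalized output is zero-mean
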
@@ -3,6 +3,7 @@
 import os
 import tempfile
 import unittest
+from unittest.mock import Mock, patch

 import mlx.core as mx
 import mlx.nn as nn
@@ -410,6 +411,60 @@ class TestNN(mlx_tests.MLXTestCase):
         self.assertTrue(np.allclose(bn.running_mean, expected_mean, atol=1e-5))
         self.assertTrue(np.allclose(bn.running_var, expected_var, atol=1e-5))

+    def test_batch_norm_stats(self):
+        batch_size = 4
+        num_features = 32
+        num_channels = 32
+        h = 28
+        w = 28
+        num_iterations = 100
+        momentum = 0.1
+
+        batch_norm = nn.BatchNorm(num_features)
+
+        batch_norm.train()
+        running_mean = np.array(batch_norm.running_mean.tolist())
+        running_var = np.array(batch_norm.running_var.tolist())
+
+        data = mx.random.normal((batch_size * num_features,)).reshape(
+            (batch_size, num_features)
+        )
+
+        for _ in range(num_iterations):
+            normalized_data = batch_norm(data)
+            means = np.mean(data.tolist(), axis=0)
+            variances = np.var(data.tolist(), axis=0)
+            running_mean = (1 - momentum) * running_mean + momentum * means
+            running_var = (1 - momentum) * running_var + momentum * variances
+            assert np.allclose(batch_norm.running_mean, running_mean, atol=1e-5)
+            assert np.allclose(batch_norm.running_var, running_var, atol=1e-5)
+            data = normalized_data
+
+        batch_norm = nn.BatchNorm(num_channels)
+
+        batch_norm.train()
+        running_mean = np.array(batch_norm.running_mean.tolist()).reshape(
+            1, num_channels, 1, 1
+        )
+        running_var = np.array(batch_norm.running_var.tolist()).reshape(
+            1, num_channels, 1, 1
+        )
+        data = mx.random.normal((batch_size, num_channels, h, w))
+
+        for _ in range(num_iterations):
+            normalized_data = batch_norm(data)
+            means = np.mean(data.tolist(), axis=(0, 2, 3)).reshape(
+                1, num_channels, 1, 1
+            )
+            variances = np.var(data.tolist(), axis=(0, 2, 3)).reshape(
+                1, num_channels, 1, 1
+            )
+            running_mean = (1 - momentum) * running_mean + momentum * means
+            running_var = (1 - momentum) * running_var + momentum * variances
+            assert np.allclose(batch_norm.running_mean, running_mean, atol=1e-5)
+            assert np.allclose(batch_norm.running_var, running_var, atol=1e-5)
+            data = normalized_data
+
     def test_conv1d(self):
         N = 5
         L = 12
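
The test added above checks the running statistics against an exponential moving average computed by hand. A self-contained sketch of just that update rule, using the same momentum convention the test assumes (stats start at zero mean and unit variance, as is conventional for batch norm):

import numpy as np

momentum = 0.1
running_mean = np.zeros(32)
running_var = np.ones(32)

for _ in range(100):
    batch = np.random.randn(4, 32)
    # new_stat = (1 - momentum) * old_stat + momentum * batch_stat
    running_mean = (1 - momentum) * running_mean + momentum * batch.mean(axis=0)
    running_var = (1 - momentum) * running_var + momentum * batch.var(axis=0)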