Mirror of https://github.com/ml-explore/mlx.git

Commit 85da6e2626: Merge branch 'main' into feature_expand_nn_linear
@@ -243,8 +243,15 @@ class BatchNorm(Module):
             self.bias = mx.zeros((num_features,))

         if self.track_running_stats:
-            self._running_mean = mx.zeros((num_features,))
-            self._running_var = mx.ones((num_features,))
+            self.running_mean = mx.zeros((num_features,))
+            self.running_var = mx.ones((num_features,))
+            self.freeze(keys=["running_mean", "running_var"], recurse=False)
+
+    def unfreeze(self, *args, **kwargs):
+        """Wrap unfreeze to make sure that running_mean and var are always
+        frozen parameters."""
+        super().unfreeze(*args, **kwargs)
+        self.freeze(keys=["running_mean", "running_var"], recurse=False)

     def _extra_repr(self):
         return (
@@ -255,46 +262,47 @@ class BatchNorm(Module):

     def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]:
         """
-        Calculate the mean and variance of the input tensor.
+        Calculate the mean and variance of the input tensor across the batch
+        and spatial dimensions.

         Args:
-            x (mx.array): Input tensor.
+            x (array): Input tensor.

         Returns:
             tuple: Tuple containing mean and variance.
         """
         reduction_axes = tuple(range(0, x.ndim - 1))
-        means = mx.mean(x, axis=reduction_axes, keepdims=True)
+
+        mean = mx.mean(x, axis=reduction_axes, keepdims=True)
         var = mx.var(x, axis=reduction_axes, keepdims=True)
-
-        if self.track_running_stats and self.training:
-            self._running_mean = (
-                1 - self.momentum
-            ) * self._running_mean + self.momentum * means
-            self._running_var = (
-                1 - self.momentum
-            ) * self._running_var + self.momentum * var
-        return means, var
+        return mean, var

     def __call__(self, x: mx.array) -> mx.array:
         """
         Forward pass of BatchNorm.

         Args:
-            x (mx.array): Input tensor.
+            x (array): Input tensor.

         Returns:
-            mx.array: Output tensor.
+            array: Normalized output tensor.
         """

         if x.ndim < 2 or x.ndim > 4:
             raise ValueError(
                 f"Expected input tensor to have 2, 3 or 4 dimensions, but got {x.ndim}"
             )

-        if self.training or not self.track_running_stats:
-            means, var = self._calc_stats(x)
-        else:
-            means, var = self._running_mean, self._running_var
-        x = (x - means) * mx.rsqrt(var + self.eps)
+        # Calculate the mean and variance used to normalize the input x. If we
+        # are in training mode update the running stats if needed.
+        mean, var = self._calc_stats(x)
+        if self.training and self.track_running_stats:
+            mu = self.momentum
+            self.running_mean = (1 - mu) * self.running_mean + mu * mean
+            self.running_var = (1 - mu) * self.running_var + mu * var
+        elif self.track_running_stats:
+            mean = self.running_mean
+            var = self.running_var
+
+        x = (x - mean) * mx.rsqrt(var + self.eps)
         return (self.weight * x + self.bias) if "weight" in self else x
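For context, a minimal usage sketch of the running-stats behavior introduced above. The module, method, and attribute names are taken from the diff; the import paths are the usual mlx ones and the input shape is illustrative.

    import mlx.core as mx
    import mlx.nn as nn

    x = mx.random.normal((8, 4))
    bn = nn.BatchNorm(num_features=4)

    # Training mode: normalize with batch statistics and update the frozen
    # running_mean / running_var buffers in place.
    bn.train()
    y = bn(x)

    # Eval mode: normalize with the stored running statistics instead.
    bn.eval()
    y = bn(x)

    # The running stats appear in parameters() but not in trainable_parameters(),
    # so an optimizer never updates them.
    print("running_mean" in bn.parameters())            # True
    print("running_mean" in bn.trainable_parameters())  # False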
@@ -1,5 +1,7 @@
 # Copyright © 2023 Apple Inc.

+from collections import defaultdict
+

 def tree_map(fn, tree, *rest, is_leaf=None):
     """Applies ``fn`` to the leaves of the python tree ``tree`` and
@@ -37,13 +39,9 @@ def tree_map(fn, tree, *rest, is_leaf=None):
     """
     if is_leaf is not None and is_leaf(tree):
         return fn(tree, *rest)
-    elif isinstance(tree, list):
-        return [
-            tree_map(fn, child, *(r[i] for r in rest), is_leaf=is_leaf)
-            for i, child in enumerate(tree)
-        ]
-    elif isinstance(tree, tuple):
-        return tuple(
+    elif isinstance(tree, (list, tuple)):
+        TreeType = type(tree)
+        return TreeType(
             tree_map(fn, child, *(r[i] for r in rest), is_leaf=is_leaf)
             for i, child in enumerate(tree)
         )
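In practice, the tree_map change above means each branch keeps its original container type. A small sketch, with tree_map imported from mlx.utils and illustrative example values:

    from mlx.utils import tree_map

    tree = [1, (2, 3), {"a": 4}]
    # Each branch is rebuilt with type(tree), so the outer list stays a list
    # and the nested tuple stays a tuple.
    doubled = tree_map(lambda v: 2 * v, tree)
    print(doubled)  # [2, (4, 6), {'a': 8}]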
@@ -128,12 +126,10 @@ def tree_unflatten(tree):
        is_list = False

    # collect children
-    children = {}
+    children = defaultdict(list)
    for key, value in tree:
        current_idx, *next_idx = key.split(".", maxsplit=1)
        next_idx = "" if not next_idx else next_idx[0]
-        if current_idx not in children:
-            children[current_idx] = []
        children[current_idx].append((next_idx, value))

    # recursively map them to the original container
@@ -141,8 +137,8 @@ def tree_unflatten(tree):
         keys = sorted((int(idx), idx) for idx in children.keys())
         l = []
         for i, k in keys:
-            while i > len(l):
-                l.append({})
+            # if i <= len(l), no {} will be appended.
+            l.extend([{} for _ in range(i - len(l))])
             l.append(tree_unflatten(children[k]))
         return l
     else:
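For reference, the tree_unflatten behavior that the defaultdict and extend changes above preserve, sketched with illustrative inputs and tree_unflatten imported from mlx.utils:

    from mlx.utils import tree_unflatten

    # Dotted keys are rebuilt into nested containers; integer components
    # become list indices.
    print(tree_unflatten([("hello.world", 42)]))
    # {'hello': {'world': 42}}

    print(tree_unflatten([("layers.0.weight", 1), ("layers.1.weight", 2)]))
    # {'layers': [{'weight': 1}, {'weight': 2}]}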
@@ -339,8 +339,8 @@ class TestNN(mlx_tests.MLXTestCase):

         # Batch norm
         bn = nn.BatchNorm(num_features=4, affine=True)
-        self.assertTrue(mx.allclose(bn._running_mean, mx.zeros_like(bn._running_mean)))
-        self.assertTrue(mx.allclose(bn._running_var, mx.ones_like(bn._running_var)))
+        self.assertTrue(mx.allclose(bn.running_mean, mx.zeros_like(bn.running_mean)))
+        self.assertTrue(mx.allclose(bn.running_var, mx.ones_like(bn.running_var)))
         y = bn(x)
         expected_y = mx.array(
             [
@@ -355,8 +355,8 @@ class TestNN(mlx_tests.MLXTestCase):
         expected_var = mx.array([0.928435, 1.00455, 1.04117, 0.94258])
         self.assertTrue(x.shape == y.shape)
         self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))
-        self.assertTrue(mx.allclose(bn._running_mean, expected_mean, atol=1e-5))
-        self.assertTrue(mx.allclose(bn._running_var, expected_var, atol=1e-5))
+        self.assertTrue(mx.allclose(bn.running_mean, expected_mean, atol=1e-5))
+        self.assertTrue(mx.allclose(bn.running_var, expected_var, atol=1e-5))

         # test eval mode
         bn.eval()
@@ -398,8 +398,8 @@ class TestNN(mlx_tests.MLXTestCase):

         # Batch norm
         bn = nn.BatchNorm(num_features=C, affine=True)
-        self.assertTrue(mx.allclose(bn._running_mean, mx.zeros_like(bn._running_mean)))
-        self.assertTrue(mx.allclose(bn._running_var, mx.ones_like(bn._running_var)))
+        self.assertTrue(mx.allclose(bn.running_mean, mx.zeros_like(bn.running_mean)))
+        self.assertTrue(mx.allclose(bn.running_var, mx.ones_like(bn.running_var)))
         y = bn(x)
         self.assertTrue(x.shape == y.shape)
         expected_y = mx.array(
@@ -423,13 +423,33 @@ class TestNN(mlx_tests.MLXTestCase):
             [[[0.00207845, -5.3259e-05, 0.04755, -0.0697296, 0.0236228]]]
         )
         expected_var = mx.array([[[0.968415, 1.05322, 0.96913, 0.932305, 0.967224]]])
-        self.assertTrue(mx.allclose(bn._running_mean, expected_mean, atol=1e-5))
-        self.assertTrue(mx.allclose(bn._running_var, expected_var, atol=1e-5))
+        self.assertTrue(mx.allclose(bn.running_mean, expected_mean, atol=1e-5))
+        self.assertTrue(mx.allclose(bn.running_var, expected_var, atol=1e-5))

         x = mx.random.normal((N, L, C, L, C), dtype=mx.float32)
         with self.assertRaises(ValueError):
             y = bn(x)

+        # Check that the running stats are in the param dictionary
+        bn_parameters = bn.parameters()
+        self.assertIn("running_mean", bn_parameters)
+        self.assertIn("running_var", bn_parameters)
+        self.assertIn("weight", bn_parameters)
+        self.assertIn("bias", bn_parameters)
+
+        bn_trainable = bn.trainable_parameters()
+        self.assertNotIn("running_mean", bn_trainable)
+        self.assertNotIn("running_var", bn_trainable)
+        self.assertIn("weight", bn_trainable)
+        self.assertIn("bias", bn_trainable)
+
+        bn.unfreeze()
+        bn_trainable = bn.trainable_parameters()
+        self.assertNotIn("running_mean", bn_trainable)
+        self.assertNotIn("running_var", bn_trainable)
+        self.assertIn("weight", bn_trainable)
+        self.assertIn("bias", bn_trainable)
+
     def test_batch_norm_stats(self):
         batch_size = 2
         num_features = 4
@@ -440,8 +460,8 @@ class TestNN(mlx_tests.MLXTestCase):
         batch_norm = nn.BatchNorm(num_features)

         batch_norm.train()
-        running_mean = np.array(batch_norm._running_mean)
-        running_var = np.array(batch_norm._running_var)
+        running_mean = np.array(batch_norm.running_mean)
+        running_var = np.array(batch_norm.running_var)

         data = mx.random.normal((batch_size, num_features))

@@ -451,14 +471,14 @@ class TestNN(mlx_tests.MLXTestCase):
         variances = np.var(np_data, axis=0)
         running_mean = (1 - momentum) * running_mean + momentum * means
         running_var = (1 - momentum) * running_var + momentum * variances
-        self.assertTrue(np.allclose(batch_norm._running_mean, running_mean, atol=1e-5))
-        self.assertTrue(np.allclose(batch_norm._running_var, running_var, atol=1e-5))
+        self.assertTrue(np.allclose(batch_norm.running_mean, running_mean, atol=1e-5))
+        self.assertTrue(np.allclose(batch_norm.running_var, running_var, atol=1e-5))

         batch_norm = nn.BatchNorm(num_features)

         batch_norm.train()
-        running_mean = np.array(batch_norm._running_mean)
-        running_var = np.array(batch_norm._running_var)
+        running_mean = np.array(batch_norm.running_mean)
+        running_var = np.array(batch_norm.running_var)
         data = mx.random.normal((batch_size, h, w, num_features))

         normalized_data = batch_norm(data)
@@ -467,8 +487,8 @@ class TestNN(mlx_tests.MLXTestCase):
         variances = np.var(np_data, axis=(0, 1, 2))
         running_mean = (1 - momentum) * running_mean + momentum * means
         running_var = (1 - momentum) * running_var + momentum * variances
-        self.assertTrue(np.allclose(batch_norm._running_mean, running_mean, atol=1e-5))
-        self.assertTrue(np.allclose(batch_norm._running_var, running_var, atol=1e-5))
+        self.assertTrue(np.allclose(batch_norm.running_mean, running_mean, atol=1e-5))
+        self.assertTrue(np.allclose(batch_norm.running_var, running_var, atol=1e-5))

     def test_conv1d(self):
         N = 5
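The tests above mirror the exponential moving average the layer applies to its running statistics. A tiny worked example of that update rule; the 0.1 momentum and the batch mean are illustrative values, not taken from the tests:

    momentum = 0.1      # illustrative momentum value
    running_mean = 0.0  # freshly initialized running mean
    batch_mean = 2.0    # mean of the current batch

    # running_mean <- (1 - momentum) * running_mean + momentum * batch_mean
    running_mean = (1 - momentum) * running_mean + momentum * batch_mean
    print(running_mean)  # 0.2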