updated BN implementation to handle input shape as NLC and NHWC
This commit is contained in:
parent 28009c9cdb
commit 9bf68814a4
@@ -196,14 +196,15 @@ class BatchNorm(Module):
     [1]: https://arxiv.org/abs/1502.03167
 
-    The input tensor shape is specified as (N, C) or (N, C, L), representing the batch size (N), the number of features or channels (C), and optionally, the sequence length (L). The output tensor maintains the same shape as the input, adhering to (N, C) or (N, C, L).
-    For four-dimensional tensors, the shape is denoted as (N, C, H, W), where N signifies the batch size, C represents the number of channels, H corresponds to the height, and W denotes the width.
+    The input tensor shape is specified as (N, C) or (N, L, C), representing the batch size (N), the number of features or channels (C), and optionally, the sequence length (L). The output tensor maintains the same shape as the input, adhering to (N, C) or (N, L, C).
+    For four-dimensional tensors, the shape is denoted as (N, H, W, C), where N signifies the batch size, C represents the number of channels, H corresponds to the height, and W denotes the width.
 
     Args:
         num_features (int): The feature dimension of the input to normalize over.
         eps (float, optional): A small additive constant for numerical stability. Default is 1e-5.
         momentum (float, optional): The momentum for updating the running mean and variance. Default is 0.1.
         affine (bool, optional): If True, learn an affine transform to apply after the normalization. Default is True.
         track_running_stats (bool, optional): If True, track the running mean and variance. Default is True.
 
     Examples:
         >>> import mlx.core as mx
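Note: a minimal usage sketch of the channels-last convention described by the updated docstring; the shapes and variable names are illustrative, and the same API is exercised by the test hunks further down.

import mlx.core as mx
import mlx.nn as nn

# 3D input in (N, L, C): batch of 2 sequences of length 4 with 5 channels.
x_nlc = mx.random.normal((2, 4, 5))
bn_seq = nn.BatchNorm(num_features=5)
y_nlc = bn_seq(x_nlc)    # output keeps the input shape (2, 4, 5)

# 4D input in (N, H, W, C): batch of 2 images, 8x8 pixels, 3 channels.
x_nhwc = mx.random.normal((2, 8, 8, 3))
bn_img = nn.BatchNorm(num_features=3)
y_nhwc = bn_img(x_nhwc)  # output keeps the input shape (2, 8, 8, 3)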
@@ -230,7 +231,6 @@ class BatchNorm(Module):
         self.momentum = momentum
         self.affine = affine
         self.track_running_stats = track_running_stats
-        self._dims_expanded = False
 
         if self.affine:
             self.weight = mx.ones((num_features,))
@@ -243,36 +243,6 @@ class BatchNorm(Module):
     def _extra_repr(self):
         return f"{self.num_features}, eps={self.eps}, momentum={self.momentum}, affine={'weight' in self}, track_running_stats={self.track_running_stats}"
 
-    def _check_and_expand_dims(self, x: mx.array):
-        """
-        Check if the input is a 2D or 3D tensor and expand the weight, bias, running mean, and running variance accordingly.
-
-        Args:
-            x (mx.array): Input tensor.
-        """
-
-        num_dims = len(x.shape)
-        dims_dict = {
-            2: ((1, self.num_features), (0,)),
-            3: ((1, 1, self.num_features), (0, 1)),
-            4: ((1, 1, 1, self.num_features), (0, 1, 2)),
-        }
-
-        if num_dims not in dims_dict:
-            raise ValueError(f"expected num_dims to be 2, 3, or 4 (got {num_dims})")
-
-        shape, self.reduction_axes = dims_dict[num_dims]
-
-        if self.affine:
-            self.weight = mx.expand_dims(self.weight, self.reduction_axes)
-            self.bias = mx.expand_dims(self.bias, self.reduction_axes)
-
-        if self.track_running_stats:
-            self.running_mean = mx.expand_dims(self.running_mean, self.reduction_axes)
-            self.running_var = mx.expand_dims(self.running_var, self.reduction_axes)
-
-        self._dims_expanded = True
-
     def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]:
         """
         Calculate the mean and variance of the input tensor.
@@ -283,17 +253,19 @@ class BatchNorm(Module):
         Returns:
             tuple: Tuple containing mean and variance.
         """
 
-        means = mx.mean(x, axis=self.reduction_axes, keepdims=True)
-        var = mx.var(x, axis=self.reduction_axes, keepdims=True)
+        reduction_axes = (
+            (0,) if len(x.shape) == 2 else (0, 1) if len(x.shape) == 3 else (0, 1, 2)
+        )
+        means = mx.mean(x, axis=reduction_axes, keepdims=True)
+        var = mx.var(x, axis=reduction_axes, keepdims=True)
 
         if self.track_running_stats and self.training:
-            self.running_mean = (
+            self._running_mean = (
                 1 - self.momentum
-            ) * self.running_mean + self.momentum * means
-            self.running_var = (
+            ) * self._running_mean + self.momentum * means
+            self._running_var = (
                 1 - self.momentum
-            ) * self.running_var + self.momentum * var
+            ) * self._running_var + self.momentum * var
         return means, var
 
     def __call__(self, x: mx.array) -> mx.array:
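Note: the hunk above computes the batch statistics over every axis except the trailing channel axis and folds them into the running buffers with an exponential moving average. A NumPy sketch of the same arithmetic (function and variable names are illustrative, not part of the MLX API):

import numpy as np

def update_running_stats(x, running_mean, running_var, momentum=0.1):
    # Reduce over every axis except the last: (0,) for (N, C),
    # (0, 1) for (N, L, C), (0, 1, 2) for (N, H, W, C).
    axes = tuple(range(x.ndim - 1))
    batch_mean = x.mean(axis=axes, keepdims=True)
    batch_var = x.var(axis=axes, keepdims=True)
    running_mean = (1 - momentum) * running_mean + momentum * batch_mean
    running_var = (1 - momentum) * running_var + momentum * batch_var
    return batch_mean, batch_var, running_mean, running_var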
@@ -307,12 +279,14 @@ class BatchNorm(Module):
             mx.array: Output tensor.
         """
 
-        if not self._dims_expanded:
-            self._check_and_expand_dims(x)
+        if x.ndim not in [2, 3, 4]:
+            raise ValueError(
+                f"Expected input tensor to have 2, 3 or 4 dimensions, but got {x.ndim}"
+            )
 
         if self.training or not self.track_running_stats:
             means, var = self._calc_stats(x)
         else:
-            means, var = self.running_mean, self.running_var
+            means, var = self._running_mean, self._running_var
         x = (x - means) * mx.rsqrt(var + self.eps)
         return (self.weight * x + self.bias) if "weight" in self else x
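Note: the branch on self.training above is what the updated tests exercise; in training mode the layer normalizes with batch statistics and updates the running buffers, while in eval mode it reuses the tracked statistics. A short illustrative sketch (shapes chosen arbitrarily):

import mlx.core as mx
import mlx.nn as nn

bn = nn.BatchNorm(num_features=8)
x = mx.random.normal((16, 8))

bn.train()       # uses batch statistics and updates the running buffers
y_train = bn(x)

bn.eval()        # uses the tracked running mean and variance instead
y_eval = bn(x)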
@@ -327,8 +327,8 @@ class TestNN(mlx_tests.MLXTestCase):
 
         # Batch norm
         bn = nn.BatchNorm(num_features=4, affine=True)
-        self.assertTrue(mx.allclose(bn.running_mean, mx.zeros_like(bn.running_mean)))
-        self.assertTrue(mx.allclose(bn.running_var, mx.ones_like(bn.running_var)))
+        self.assertTrue(mx.allclose(bn._running_mean, mx.zeros_like(bn._running_mean)))
+        self.assertTrue(mx.allclose(bn._running_var, mx.ones_like(bn._running_var)))
         y = bn(x)
         expected_y = mx.array(
             [
@@ -343,8 +343,8 @@ class TestNN(mlx_tests.MLXTestCase):
         expected_var = mx.array([0.928435, 1.00455, 1.04117, 0.94258])
         self.assertTrue(x.shape == y.shape)
         self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
-        self.assertTrue(np.allclose(bn.running_mean, expected_mean, atol=1e-5))
-        self.assertTrue(np.allclose(bn.running_var, expected_var, atol=1e-5))
+        self.assertTrue(np.allclose(bn._running_mean, expected_mean, atol=1e-5))
+        self.assertTrue(np.allclose(bn._running_var, expected_var, atol=1e-5))
 
         # test eval mode
         bn.eval()
@@ -379,37 +379,44 @@ class TestNN(mlx_tests.MLXTestCase):
 
         # test with 3D input
        mx.random.seed(42)
-        x = mx.random.normal((2, 4, 3), dtype=mx.float32)
+        N = 2
+        L = 4
+        C = 5
+        x = mx.random.normal((N, L, C), dtype=mx.float32)
 
         # Batch norm
-        bn = nn.BatchNorm(num_features=4, affine=True)
-        self.assertTrue(mx.allclose(bn.running_mean, mx.zeros_like(bn.running_mean)))
-        self.assertTrue(mx.allclose(bn.running_var, mx.ones_like(bn.running_var)))
+        bn = nn.BatchNorm(num_features=C, affine=True)
+        self.assertTrue(mx.allclose(bn._running_mean, mx.zeros_like(bn._running_mean)))
+        self.assertTrue(mx.allclose(bn._running_var, mx.ones_like(bn._running_var)))
         y = bn(x)
         self.assertTrue(x.shape == y.shape)
         expected_y = mx.array(
             [
                 [
-                    [-0.285056, -0.657241, 0.584881],
-                    [1.079424, 0.795527, 0.163417],
-                    [-0.351929, 0.669030, 1.713490],
-                    [-0.679080, -1.467115, 1.077580],
+                    [-0.335754, 0.342054, 1.02653, 0.628588, -1.63899],
+                    [1.92092, 0.432319, 0.343043, 1.95489, 1.0696],
+                    [-0.853748, 1.3661, 0.868569, 0.0199196, -0.887284],
+                    [0.459206, -0.684822, -0.706354, -0.271531, 0.566341],
                 ],
                 [
-                    [-0.091968, -1.362007, 1.811391],
-                    [-1.654407, -1.017945, 0.633983],
-                    [-1.309168, 0.148356, -0.869779],
-                    [-0.742132, 1.037774, 0.772974],
+                    [-0.921179, 0.684951, -0.77466, -0.490372, -0.247032],
+                    [1.10839, -2.13179, 0.628924, -1.62639, -0.539708],
+                    [-0.348943, 0.412194, -2.03818, 0.524972, 1.64568],
+                    [-1.02889, -0.421, 0.652127, -0.740079, 0.0313996],
                 ],
             ]
         )
         self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
         expected_mean = mx.array(
-            [[[0.0362097], [0.0360611], [0.0166926], [-0.0111884]]]
+            [[[0.00207845, -5.3259e-05, 0.04755, -0.0697296, 0.0236228]]]
         )
-        expected_var = mx.array([[[1.07218], [0.992639], [1.01724], [1.16217]]])
-        self.assertTrue(np.allclose(bn.running_mean, expected_mean, atol=1e-5))
-        self.assertTrue(np.allclose(bn.running_var, expected_var, atol=1e-5))
+        expected_var = mx.array([[[0.968415, 1.05322, 0.96913, 0.932305, 0.967224]]])
+        self.assertTrue(np.allclose(bn._running_mean, expected_mean, atol=1e-5))
+        self.assertTrue(np.allclose(bn._running_var, expected_var, atol=1e-5))
+
+        x = mx.random.normal((N, L, C, L, C), dtype=mx.float32)
+        with self.assertRaises(ValueError):
+            y = bn(x)
 
     def test_batch_norm_stats(self):
         batch_size = 4
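Note: the expected running-statistic values above follow from a single momentum update of freshly initialized buffers (mean starts at 0, variance at 1, momentum defaults to 0.1). A NumPy sketch of that arithmetic; the random data here stands in for the test's (N, L, C) batch:

import numpy as np

# In the test, x_np would be np.array(x.tolist()); random data stands in here.
x_np = np.random.randn(2, 4, 5)       # (N, L, C)
batch_mean = x_np.mean(axis=(0, 1))   # per-channel mean, shape (C,)
batch_var = x_np.var(axis=(0, 1))     # per-channel variance, shape (C,)

momentum = 0.1
expected_mean = (1 - momentum) * 0.0 + momentum * batch_mean  # what bn._running_mean should hold
expected_var = (1 - momentum) * 1.0 + momentum * batch_var    # what bn._running_var should hold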
@@ -423,8 +430,8 @@ class TestNN(mlx_tests.MLXTestCase):
         batch_norm = nn.BatchNorm(num_features)
 
         batch_norm.train()
-        running_mean = np.array(batch_norm.running_mean.tolist())
-        running_var = np.array(batch_norm.running_var.tolist())
+        running_mean = np.array(batch_norm._running_mean.tolist())
+        running_var = np.array(batch_norm._running_var.tolist())
 
         data = mx.random.normal((batch_size * num_features,)).reshape(
             (batch_size, num_features)
@@ -436,33 +443,33 @@ class TestNN(mlx_tests.MLXTestCase):
             variances = np.var(data.tolist(), axis=0)
             running_mean = (1 - momentum) * running_mean + momentum * means
             running_var = (1 - momentum) * running_var + momentum * variances
-            assert np.allclose(batch_norm.running_mean, running_mean, atol=1e-5)
-            assert np.allclose(batch_norm.running_var, running_var, atol=1e-5)
+            assert np.allclose(batch_norm._running_mean, running_mean, atol=1e-5)
+            assert np.allclose(batch_norm._running_var, running_var, atol=1e-5)
             data = normalized_data
 
         batch_norm = nn.BatchNorm(num_channels)
 
         batch_norm.train()
-        running_mean = np.array(batch_norm.running_mean.tolist()).reshape(
-            1, num_channels, 1, 1
+        running_mean = np.array(batch_norm._running_mean.tolist()).reshape(
+            1, 1, 1, num_channels
         )
-        running_var = np.array(batch_norm.running_var.tolist()).reshape(
-            1, num_channels, 1, 1
+        running_var = np.array(batch_norm._running_var.tolist()).reshape(
+            1, 1, 1, num_channels
         )
-        data = mx.random.normal((batch_size, num_channels, h, w))
+        data = mx.random.normal((batch_size, h, w, num_channels))
 
         for _ in range(num_iterations):
             normalized_data = batch_norm(data)
-            means = np.mean(data.tolist(), axis=(0, 2, 3)).reshape(
-                1, num_channels, 1, 1
+            means = np.mean(data.tolist(), axis=(0, 1, 2)).reshape(
+                1, 1, 1, num_channels
             )
-            variances = np.var(data.tolist(), axis=(0, 2, 3)).reshape(
-                1, num_channels, 1, 1
+            variances = np.var(data.tolist(), axis=(0, 1, 2)).reshape(
+                1, 1, 1, num_channels
             )
             running_mean = (1 - momentum) * running_mean + momentum * means
             running_var = (1 - momentum) * running_var + momentum * variances
-            assert np.allclose(batch_norm.running_mean, running_mean, atol=1e-5)
-            assert np.allclose(batch_norm.running_var, running_var, atol=1e-5)
+            assert np.allclose(batch_norm._running_mean, running_mean, atol=1e-5)
+            assert np.allclose(batch_norm._running_var, running_var, atol=1e-5)
             data = normalized_data
 
     def test_conv1d(self):
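Note: the reshape change in the hunk above is the broadcasting counterpart of the layout change; per-channel statistics must live in the trailing axis to broadcast against an (N, H, W, C) batch. A minimal NumPy illustration (shapes are arbitrary):

import numpy as np

data = np.random.randn(4, 4, 4, 32)                      # (N, H, W, C)
mean = data.mean(axis=(0, 1, 2)).reshape(1, 1, 1, -1)    # (1, 1, 1, C)
var = data.var(axis=(0, 1, 2)).reshape(1, 1, 1, -1)      # (1, 1, 1, C)
normalized = (data - mean) / np.sqrt(var + 1e-5)         # broadcasts over N, H, W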