updated BN implementation to handle input shapes NLC and NHWC

m0saan 2023-12-24 23:04:31 +01:00
parent 28009c9cdb
commit 9bf68814a4
2 changed files with 59 additions and 78 deletions


@@ -196,14 +196,15 @@ class BatchNorm(Module):
[1]: https://arxiv.org/abs/1502.03167
The input tensor shape is specified as (N, C) or (N, C, L), representing the batch size (N), the number of features or channels (C), and optionally, the sequence length (L). The output tensor maintains the same shape as the input, adhering to (N, C) or (N, C, L).
For three-dimensional tensors, the shape is denoted as (N, C, H, W), where N signifies the batch size, C represents the number of channels, H corresponds to the height, and W denotes the width.
The input tensor shape is specified as (N, C) or (N, L, C), representing the batch size (N), the number of features or channels (C), and optionally, the sequence length (L). The output tensor maintains the same shape as the input, adhering to (N, C) or (N, L, C).
For four-dimensional tensors, the shape is denoted as (N, H, W, C), where N signifies the batch size, H corresponds to the height, W denotes the width, and C represents the number of channels.
Args:
num_features (int): The feature dimension of the input to normalize over.
eps (float, optional): A small additive constant for numerical stability. Default is 1e-5.
momentum (float, optional): The momentum for updating the running mean and variance. Default is 0.1.
affine (bool, optional): If True, learn an affine transform to apply after the normalization. Default is True.
track_running_stats (bool, optional): If True, track the running mean and variance. Default is True.
Examples:
>>> import mlx.core as mx
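A minimal usage sketch of the channels-last convention described above (shapes are illustrative, not taken from the diff; assumes the standard mlx.nn import):
import mlx.core as mx
import mlx.nn as nn
bn = nn.BatchNorm(num_features=5)
x_nc = mx.random.normal((2, 5))          # (N, C)
x_nlc = mx.random.normal((2, 4, 5))      # (N, L, C)
x_nhwc = mx.random.normal((2, 3, 3, 5))  # (N, H, W, C)
# Output shape always matches the input shape; stats are taken over all axes except the last.
y_nc, y_nlc, y_nhwc = bn(x_nc), bn(x_nlc), bn(x_nhwc)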
@@ -230,7 +231,6 @@ class BatchNorm(Module):
self.momentum = momentum
self.affine = affine
self.track_running_stats = track_running_stats
self._dims_expanded = False
if self.affine:
self.weight = mx.ones((num_features,))
@@ -243,36 +243,6 @@ class BatchNorm(Module):
def _extra_repr(self):
return f"{self.num_features}, eps={self.eps}, momentum={self.momentum}, affine={'weight' in self}, track_running_stats={self.track_running_stats}"
def _check_and_expand_dims(self, x: mx.array):
"""
Check if the input is a 2D or 3D tensor and expand the weight, bias, running mean, and running variance accordingly.
Args:
x (mx.array): Input tensor.
"""
num_dims = len(x.shape)
dims_dict = {
2: ((1, self.num_features), (0,)),
3: ((1, 1, self.num_features), (0, 1)),
4: ((1, 1, 1, self.num_features), (0, 1, 2)),
}
if num_dims not in dims_dict:
raise ValueError(f"expected num_dims to be 2, 3, or 4 (got {num_dims})")
shape, self.reduction_axes = dims_dict[num_dims]
if self.affine:
self.weight = mx.expand_dims(self.weight, self.reduction_axes)
self.bias = mx.expand_dims(self.bias, self.reduction_axes)
if self.track_running_stats:
self.running_mean = mx.expand_dims(self.running_mean, self.reduction_axes)
self.running_var = mx.expand_dims(self.running_var, self.reduction_axes)
self._dims_expanded = True
def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]:
"""
Calculate the mean and variance of the input tensor.
@@ -283,17 +253,19 @@ class BatchNorm(Module):
Returns:
tuple: Tuple containing mean and variance.
"""
means = mx.mean(x, axis=self.reduction_axes, keepdims=True)
var = mx.var(x, axis=self.reduction_axes, keepdims=True)
reduction_axes = (
(0,) if len(x.shape) == 2 else (0, 1) if len(x.shape) == 3 else (0, 1, 2)
)
means = mx.mean(x, axis=reduction_axes, keepdims=True)
var = mx.var(x, axis=reduction_axes, keepdims=True)
if self.track_running_stats and self.training:
self.running_mean = (
self._running_mean = (
1 - self.momentum
) * self.running_mean + self.momentum * means
self.running_var = (
) * self._running_mean + self.momentum * means
self._running_var = (
1 - self.momentum
) * self.running_var + self.momentum * var
) * self._running_var + self.momentum * var
return means, var
def __call__(self, x: mx.array) -> mx.array:
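The inline conditional above picks the reduction axes from the input rank, always keeping the trailing channel axis out of the reduction; the same rule written as a small standalone helper (hypothetical, for illustration only):
def _reduction_axes(ndim: int) -> tuple:
    # Reduce over every axis except the last (channel) axis:
    # 2D (N, C) -> (0,), 3D (N, L, C) -> (0, 1), 4D (N, H, W, C) -> (0, 1, 2)
    return tuple(range(ndim - 1))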
@@ -307,12 +279,14 @@ class BatchNorm(Module):
mx.array: Output tensor.
"""
if not self._dims_expanded:
self._check_and_expand_dims(x)
if x.ndim not in [2, 3, 4]:
raise ValueError(
f"Expected input tensor to have 2, 3 or 4 dimensions, but got {x.ndim}"
)
if self.training or not self.track_running_stats:
means, var = self._calc_stats(x)
else:
means, var = self.running_mean, self.running_var
means, var = self._running_mean, self._running_var
x = (x - means) * mx.rsqrt(var + self.eps)
return (self.weight * x + self.bias) if "weight" in self else x
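In training mode the forward pass reduces to the standard batch-norm equation; a rough NumPy cross-check of the same computation for an (N, L, C) input (illustrative only, affine parameters treated as identity):
import numpy as np
x = np.random.randn(2, 4, 5).astype(np.float32)  # (N, L, C)
mean = x.mean(axis=(0, 1), keepdims=True)
var = x.var(axis=(0, 1), keepdims=True)
y_ref = (x - mean) / np.sqrt(var + 1e-5)          # mirrors (x - means) * mx.rsqrt(var + eps)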


@@ -327,8 +327,8 @@ class TestNN(mlx_tests.MLXTestCase):
# Batch norm
bn = nn.BatchNorm(num_features=4, affine=True)
self.assertTrue(mx.allclose(bn.running_mean, mx.zeros_like(bn.running_mean)))
self.assertTrue(mx.allclose(bn.running_var, mx.ones_like(bn.running_var)))
self.assertTrue(mx.allclose(bn._running_mean, mx.zeros_like(bn._running_mean)))
self.assertTrue(mx.allclose(bn._running_var, mx.ones_like(bn._running_var)))
y = bn(x)
expected_y = mx.array(
[
@@ -343,8 +343,8 @@ class TestNN(mlx_tests.MLXTestCase):
expected_var = mx.array([0.928435, 1.00455, 1.04117, 0.94258])
self.assertTrue(x.shape == y.shape)
self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
self.assertTrue(np.allclose(bn.running_mean, expected_mean, atol=1e-5))
self.assertTrue(np.allclose(bn.running_var, expected_var, atol=1e-5))
self.assertTrue(np.allclose(bn._running_mean, expected_mean, atol=1e-5))
self.assertTrue(np.allclose(bn._running_var, expected_var, atol=1e-5))
# test eval mode
bn.eval()
@@ -379,37 +379,44 @@ class TestNN(mlx_tests.MLXTestCase):
# test with 3D input
mx.random.seed(42)
x = mx.random.normal((2, 4, 3), dtype=mx.float32)
N = 2
L = 4
C = 5
x = mx.random.normal((N, L, C), dtype=mx.float32)
# Batch norm
bn = nn.BatchNorm(num_features=4, affine=True)
self.assertTrue(mx.allclose(bn.running_mean, mx.zeros_like(bn.running_mean)))
self.assertTrue(mx.allclose(bn.running_var, mx.ones_like(bn.running_var)))
bn = nn.BatchNorm(num_features=C, affine=True)
self.assertTrue(mx.allclose(bn._running_mean, mx.zeros_like(bn._running_mean)))
self.assertTrue(mx.allclose(bn._running_var, mx.ones_like(bn._running_var)))
y = bn(x)
self.assertTrue(x.shape == y.shape)
expected_y = mx.array(
[
[
[-0.285056, -0.657241, 0.584881],
[1.079424, 0.795527, 0.163417],
[-0.351929, 0.669030, 1.713490],
[-0.679080, -1.467115, 1.077580],
[-0.335754, 0.342054, 1.02653, 0.628588, -1.63899],
[1.92092, 0.432319, 0.343043, 1.95489, 1.0696],
[-0.853748, 1.3661, 0.868569, 0.0199196, -0.887284],
[0.459206, -0.684822, -0.706354, -0.271531, 0.566341],
],
[
[-0.091968, -1.362007, 1.811391],
[-1.654407, -1.017945, 0.633983],
[-1.309168, 0.148356, -0.869779],
[-0.742132, 1.037774, 0.772974],
[-0.921179, 0.684951, -0.77466, -0.490372, -0.247032],
[1.10839, -2.13179, 0.628924, -1.62639, -0.539708],
[-0.348943, 0.412194, -2.03818, 0.524972, 1.64568],
[-1.02889, -0.421, 0.652127, -0.740079, 0.0313996],
],
]
)
self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
expected_mean = mx.array(
[[[0.0362097], [0.0360611], [0.0166926], [-0.0111884]]]
[[[0.00207845, -5.3259e-05, 0.04755, -0.0697296, 0.0236228]]]
)
expected_var = mx.array([[[1.07218], [0.992639], [1.01724], [1.16217]]])
self.assertTrue(np.allclose(bn.running_mean, expected_mean, atol=1e-5))
self.assertTrue(np.allclose(bn.running_var, expected_var, atol=1e-5))
expected_var = mx.array([[[0.968415, 1.05322, 0.96913, 0.932305, 0.967224]]])
self.assertTrue(np.allclose(bn._running_mean, expected_mean, atol=1e-5))
self.assertTrue(np.allclose(bn._running_var, expected_var, atol=1e-5))
x = mx.random.normal((N, L, C, L, C), dtype=mx.float32)
with self.assertRaises(ValueError):
y = bn(x)
def test_batch_norm_stats(self):
batch_size = 4
@@ -423,8 +430,8 @@ class TestNN(mlx_tests.MLXTestCase):
batch_norm = nn.BatchNorm(num_features)
batch_norm.train()
running_mean = np.array(batch_norm.running_mean.tolist())
running_var = np.array(batch_norm.running_var.tolist())
running_mean = np.array(batch_norm._running_mean.tolist())
running_var = np.array(batch_norm._running_var.tolist())
data = mx.random.normal((batch_size * num_features,)).reshape(
(batch_size, num_features)
@@ -436,33 +443,33 @@ class TestNN(mlx_tests.MLXTestCase):
variances = np.var(data.tolist(), axis=0)
running_mean = (1 - momentum) * running_mean + momentum * means
running_var = (1 - momentum) * running_var + momentum * variances
assert np.allclose(batch_norm.running_mean, running_mean, atol=1e-5)
assert np.allclose(batch_norm.running_var, running_var, atol=1e-5)
assert np.allclose(batch_norm._running_mean, running_mean, atol=1e-5)
assert np.allclose(batch_norm._running_var, running_var, atol=1e-5)
data = normalized_data
batch_norm = nn.BatchNorm(num_channels)
batch_norm.train()
running_mean = np.array(batch_norm.running_mean.tolist()).reshape(
1, num_channels, 1, 1
running_mean = np.array(batch_norm._running_mean.tolist()).reshape(
1, 1, 1, num_channels
)
running_var = np.array(batch_norm.running_var.tolist()).reshape(
1, num_channels, 1, 1
running_var = np.array(batch_norm._running_var.tolist()).reshape(
1, 1, 1, num_channels
)
data = mx.random.normal((batch_size, num_channels, h, w))
data = mx.random.normal((batch_size, h, w, num_channels))
for _ in range(num_iterations):
normalized_data = batch_norm(data)
means = np.mean(data.tolist(), axis=(0, 2, 3)).reshape(
1, num_channels, 1, 1
means = np.mean(data.tolist(), axis=(0, 1, 2)).reshape(
1, 1, 1, num_channels
)
variances = np.var(data.tolist(), axis=(0, 2, 3)).reshape(
1, num_channels, 1, 1
variances = np.var(data.tolist(), axis=(0, 1, 2)).reshape(
1, 1, 1, num_channels
)
running_mean = (1 - momentum) * running_mean + momentum * means
running_var = (1 - momentum) * running_var + momentum * variances
assert np.allclose(batch_norm.running_mean, running_mean, atol=1e-5)
assert np.allclose(batch_norm.running_var, running_var, atol=1e-5)
assert np.allclose(batch_norm._running_mean, running_mean, atol=1e-5)
assert np.allclose(batch_norm._running_var, running_var, atol=1e-5)
data = normalized_data
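For the four-dimensional case the loop above checks the same exponential-moving-average update with channels-last reductions; a compact NumPy sketch of one such step (names and sizes are illustrative):
import numpy as np
momentum, num_channels = 0.1, 2
running_mean = np.zeros((1, 1, 1, num_channels))
running_var = np.ones((1, 1, 1, num_channels))
data = np.random.randn(4, 3, 3, num_channels)  # (N, H, W, C)
batch_mean = data.mean(axis=(0, 1, 2), keepdims=True)
batch_var = data.var(axis=(0, 1, 2), keepdims=True)
running_mean = (1 - momentum) * running_mean + momentum * batch_mean
running_var = (1 - momentum) * running_var + momentum * batch_var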
def test_conv1d(self):