implemented instancenorm

This commit is contained in:
Gabrijel Boduljak 2023-12-21 04:23:09 +01:00
parent f24200db2c
commit 6c7bebf428
3 changed files with 266 additions and 1 deletion
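
The commit adds an InstanceNorm layer to mlx.nn, which normalizes each sample's channels independently over the spatial axes. A minimal usage sketch (not part of the diff; shapes are illustrative, assuming the channels-second (N, C, ...) layout this implementation expects):

import mlx.core as mx
import mlx.nn as nn

# A batch of 2 samples with 3 channels and 4 time steps: shape (N, C, L)
x = mx.random.normal((2, 3, 4))

inorm = nn.InstanceNorm(num_features=3)
y = inorm(x)  # same shape as x; each (sample, channel) slice now has
              # zero mean and unit variance over the spatial axis

# With affine=True the layer adds a learnable per-channel scale and shift
inorm_affine = nn.InstanceNorm(num_features=3, affine=True)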


@@ -36,7 +36,7 @@ from mlx.nn.layers.convolution import Conv1d, Conv2d
from mlx.nn.layers.dropout import Dropout
from mlx.nn.layers.embedding import Embedding
from mlx.nn.layers.linear import Linear
-from mlx.nn.layers.normalization import GroupNorm, LayerNorm, RMSNorm
+from mlx.nn.layers.normalization import GroupNorm, InstanceNorm, LayerNorm, RMSNorm
from mlx.nn.layers.positional_encoding import RoPE, SinusoidalPositionalEncoding
from mlx.nn.layers.transformer import (
MultiHeadAttention,


@@ -4,6 +4,80 @@ import mlx.core as mx
from mlx.nn.layers.base import Module

class InstanceNorm(Module):
    r"""Applies instance normalization [1] on the inputs.

    Computes

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta,

    where :math:`\gamma` and :math:`\beta` are learned per feature dimension
    parameters initialized at 1 and 0 respectively. Both have size ``num_features``
    and are only present when :attr:`affine` is ``True``.

    [1]: https://arxiv.org/abs/1607.08022

    Args:
        num_features: number of features (channels) of the input
        eps: a small constant added to the denominator for numerical
            stability. Default: ``1e-5``
        affine: if ``True``, the module has learnable per-channel affine
            parameters, initialized at 1 (weight) and 0 (bias).
            Default: ``False``

    Shape:
        - Input: :math:`(N, C, L)`, :math:`(N, C, H, W)` or :math:`(N, C, D, H, W)`
        - Output: same shape as the input
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        affine: bool = False,
    ):
        super().__init__()
        if affine:
            self.weight = mx.ones((num_features,))
            self.bias = mx.zeros((num_features,))
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        self.param_shapes = {
            3: (1, num_features, 1),  # input shape: (B, C, L)
            4: (1, num_features, 1, 1),  # input shape: (B, C, H, W)
            5: (1, num_features, 1, 1, 1),  # input shape: (B, C, D, H, W)
        }
        self.reduction_axes = {
            3: [2],  # input shape: (B, C, L)
            4: [2, 3],  # input shape: (B, C, H, W)
            5: [2, 3, 4],  # input shape: (B, C, D, H, W)
        }

    def extra_repr(self):
        return "{num_features}, eps={eps}, affine={affine}".format(**self.__dict__)

    def __call__(self, x: mx.array) -> mx.array:
        if x.ndim not in self.reduction_axes:
            raise ValueError(
                f"Unsupported input with {x.ndim} dimensions; expected 3, 4 or 5"
            )
        if self.affine and self.weight.ndim != x.ndim:
            # Reshape the parameters once so they broadcast over the spatial axes
            self.weight = mx.reshape(self.weight, self.param_shapes[x.ndim])
            self.bias = mx.reshape(self.bias, self.param_shapes[x.ndim])
        # Compute per-sample, per-channel statistics over the spatial axes
        mean = mx.mean(x, axis=self.reduction_axes[x.ndim], keepdims=True)
        var = mx.var(x, axis=self.reduction_axes[x.ndim], keepdims=True)
        # Normalize
        normalized = (x - mean) * mx.rsqrt(var + self.eps)
        # Scale and shift if necessary
        if self.affine:
            return self.weight * normalized + self.bias
        return normalized


class LayerNorm(Module):
    r"""Applies layer normalization [1] on the inputs.


@@ -226,6 +226,197 @@ class TestNN(mlx_tests.MLXTestCase):
        self.assertTrue(np.allclose(means, 3 * np.ones_like(means), atol=1e-6))
        self.assertTrue(np.allclose(var, 4 * np.ones_like(var), atol=1e-6))

    def test_instance_norm(self):
        # Test InstanceNorm1d
        x = mx.array(
            [
                [
                    [-0.0119524, -0.500331, 1.12958, 1.39955],
                    [1.1263, 0.517899, -0.21413, 0.891329],
                    [2.02223, -1.21143, -2.48738, 1.63289],
                ],
                [
                    [0.241417, -1.42512, 2.739, -1.23175],
                    [-0.619157, 0.970817, -1.2506, 0.32756],
                    [-0.77484, -1.31352, 1.56844, 1.13969],
                ],
            ]
        )
        inorm = nn.InstanceNorm(num_features=3)
        y = inorm(x)
        expected_y = [
            [
                [-0.657082, -1.27879, 0.796097, 1.13978],
                [1.07593, -0.123075, -1.56572, 0.61286],
                [1.0712, -0.632503, -1.30476, 0.866066],
            ],
            [
                [0.0964433, -0.904773, 1.59693, -0.788599],
                [-0.557908, 1.30444, -1.29751, 0.550987],
                [-0.759886, -1.20013, 1.15521, 0.804804],
            ],
        ]
        self.assertTrue(x.shape == y.shape)
        self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
        # Test InstanceNorm2d
        x = mx.array(
            [
                [
                    [
                        [-0.458824, -0.447996, 0.0486988],
                        [1.13049, 0.301795, -2.23876],
                        [0.0986325, -1.25257, -0.329399],
                    ],
                    [
                        [0.483254, -0.176577, -0.0611224],
                        [0.345315, 0.99207, -0.758631],
                        [-1.82973, 0.154442, -0.319107],
                    ],
                    [
                        [-0.58611, -0.622545, 1.8845],
                        [-0.926389, -0.184927, -1.12639],
                        [-0.241765, -0.556204, 0.830584],
                    ],
                ],
                [
                    [
                        [1.04407, 0.0800776, 0.782321],
                        [0.671423, -0.110299, 0.159905],
                        [0.810252, 0.182597, -0.0621687],
                    ],
                    [
                        [0.073752, 1.2513, -0.444367],
                        [-1.21689, -1.42248, 0.516452],
                        [1.50456, 0.0576239, 0.184253],
                    ],
                    [
                        [0.407081, 1.20627, 0.563132],
                        [-1.88979, 1.17838, -0.539121],
                        [1.08659, 0.973883, 0.784216],
                    ],
                ],
            ]
        )
        inorm = nn.InstanceNorm(num_features=3)
        y = inorm(x)
        expected_y = [
            [
                [
                    [-0.120422, -0.108465, 0.440008],
                    [1.63457, 0.719488, -2.08591],
                    [0.495147, -0.996913, 0.0224944],
                ],
                [
                    [0.801504, -0.0608616, 0.0900314],
                    [0.621224, 1.4665, -0.821576],
                    [-2.22144, 0.371763, -0.247141],
                ],
                [
                    [-0.463984, -0.504602, 2.29032],
                    [-0.843336, -0.0167355, -1.0663],
                    [-0.0800997, -0.430644, 1.11538],
                ],
            ],
            [
                [
                    [1.59749, -0.776381, 0.95293],
                    [0.679838, -1.24519, -0.579803],
                    [1.02171, -0.523923, -1.12667],
                ],
                [
                    [0.0190289, 1.28291, -0.537076],
                    [-1.36624, -1.5869, 0.494185],
                    [1.55474, 0.00171834, 0.137631],
                ],
                [
                    [-0.012331, 0.817234, 0.149652],
                    [-2.39651, 0.78829, -0.994498],
                    [0.693007, 0.576016, 0.37914],
                ],
            ],
        ]
        self.assertTrue(x.shape == y.shape)
        self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
        # Test InstanceNorm3d
        x = mx.array(
            [
                [
                    [
                        [[0.777621, -2.1722], [-1.41317, 0.284446]],
                        [[0.11, -0.837743], [-2.40205, 0.336682]],
                        [[0.789185, -1.42998], [-0.459489, 0.0298199]],
                    ],
                    [
                        [[0.528145, 0.128192], [0.476288, -0.649858]],
                        [[-0.12431, 1.93502], [-1.25873, -0.261986]],
                        [[-1.63747, -1.73247], [-2.15559, 0.10275]],
                    ],
                    [
                        [[-1.56133, 0.153862], [-1.20411, 0.152112]],
                        [[1.18768, 0.00236324], [-2.04243, 1.54289]],
                        [[0.67917, -0.402572], [-0.249959, -0.821897]],
                    ],
                ],
                [
                    [
                        [[-2.12354, 0.317797], [-0.146628, 0.0329215]],
                        [[-1.55784, 2.41031], [0.226341, 0.265387]],
                        [[0.990317, 0.475161], [-1.37804, -0.501041]],
                    ],
                    [
                        [[0.643973, -0.682916], [-0.987925, 1.54086]],
                        [[0.71179, -0.290786], [0.057712, -0.742304]],
                        [[-0.399875, -1.10479], [1.40097, 0.0723374]],
                    ],
                    [
                        [[0.72391, 0.016364], [0.573199, 0.213092]],
                        [[-0.0678402, 0.00449439], [-1.58342, 1.28133]],
                        [[-0.357647, -1.07389], [0.141618, -0.386141]],
                    ],
                ],
            ]
        )
        inorm = nn.InstanceNorm(num_features=3)
        y = inorm(x)
        expected_y = [
            [
                [
                    [[1.23593, -1.54739], [-0.831204, 0.770588]],
                    [[0.605988, -0.288258], [-1.76427, 0.819875]],
                    [[1.24684, -0.847068], [0.0686449, 0.530334]],
                ],
                [
                    [[0.821849, 0.462867], [0.775304, -0.23548]],
                    [[0.236231, 2.0846], [-0.78198, 0.112659]],
                    [[-1.12192, -1.20719], [-1.58697, 0.440032]],
                ],
                [
                    [[-1.30944, 0.357126], [-0.962338, 0.355425]],
                    [[1.36163, 0.209922], [-1.77689, 1.70677]],
                    [[0.867539, -0.183531], [-0.0352458, -0.590967]],
                ],
            ],
            [
                [
                    [[-1.75315, 0.343736], [-0.0551618, 0.0990544]],
                    [[-1.26726, 2.14101], [0.265184, 0.298721]],
                    [[0.921369, 0.478897], [-1.11283, -0.35957]],
                ],
                [
                    [[0.733967, -0.822472], [-1.18025, 1.78602]],
                    [[0.813517, -0.362504], [0.0462839, -0.892134]],
                    [[-0.490465, -1.31732], [1.62192, 0.0634394]],
                ],
                [
                    [[1.04349, 0.080661], [0.838402, 0.348368]],
                    [[-0.033924, 0.0645089], [-2.09632, 1.80203]],
                    [[-0.428293, -1.40296], [0.251107, -0.467067]],
                ],
            ],
        ]
        self.assertTrue(x.shape == y.shape)
        self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
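
    # A possible cross-check of the hard-coded expected values against
    # PyTorch's InstanceNorm layers (a sketch; this helper is hypothetical,
    # not part of the commit, and assumes torch is installed and inputs use
    # the channels-second (N, C, ...) layout):
    def _torch_instance_norm_reference(self, x, y, atol=1e-5):
        import torch

        # Pick the torch layer matching the input's dimensionality
        layers = {
            3: torch.nn.InstanceNorm1d,  # (N, C, L)
            4: torch.nn.InstanceNorm2d,  # (N, C, H, W)
            5: torch.nn.InstanceNorm3d,  # (N, C, D, H, W)
        }
        ref = layers[x.ndim](x.shape[1], eps=1e-5)(torch.tensor(np.array(x)))
        self.assertTrue(np.allclose(np.array(ref), np.array(y), atol=atol))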

    def test_conv1d(self):
        N = 5
        L = 12