From 6c7bebf428c7e18a76b621a647a5e02b25fe1647 Mon Sep 17 00:00:00 2001
From: Gabrijel Boduljak
Date: Thu, 21 Dec 2023 04:23:09 +0100
Subject: [PATCH] implemented instancenorm

---
 python/mlx/nn/layers/__init__.py      |   2 +-
 python/mlx/nn/layers/normalization.py |  74 ++++++++++
 python/tests/test_nn.py               | 191 ++++++++++++++++++++++
 3 files changed, 266 insertions(+), 1 deletion(-)

diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py
index 04557843b..76f7ee981 100644
--- a/python/mlx/nn/layers/__init__.py
+++ b/python/mlx/nn/layers/__init__.py
@@ -36,7 +36,7 @@ from mlx.nn.layers.convolution import Conv1d, Conv2d
 from mlx.nn.layers.dropout import Dropout
 from mlx.nn.layers.embedding import Embedding
 from mlx.nn.layers.linear import Linear
-from mlx.nn.layers.normalization import GroupNorm, LayerNorm, RMSNorm
+from mlx.nn.layers.normalization import GroupNorm, InstanceNorm, LayerNorm, RMSNorm
 from mlx.nn.layers.positional_encoding import RoPE, SinusoidalPositionalEncoding
 from mlx.nn.layers.transformer import (
     MultiHeadAttention,
diff --git a/python/mlx/nn/layers/normalization.py b/python/mlx/nn/layers/normalization.py
index 6de377cda..f87eddbc6 100644
--- a/python/mlx/nn/layers/normalization.py
+++ b/python/mlx/nn/layers/normalization.py
@@ -4,6 +4,80 @@ import mlx.core as mx
 from mlx.nn.layers.base import Module
 
 
+class InstanceNorm(Module):
+    r"""Applies instance normalization [1] on the inputs.
+
+    Computes
+
+    .. math::
+
+        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta,
+
+    where :math:`\gamma` and :math:`\beta` are learned per-feature parameters
+    initialized to 1 and 0 respectively. Both have shape ``(num_features,)``
+    and are only applied when :attr:`affine` is ``True``.
+
+    [1]: https://arxiv.org/abs/1607.08022
+
+    Args:
+        num_features: number of feature channels in the input
+        eps: a small value added to the denominator for numerical stability.
+            Default: ``1e-5``.
+        affine: if ``True``, the module has learnable per-feature scale
+            (:math:`\gamma`) and shift (:math:`\beta`) parameters.
+            Default: ``False``.
+
+    Shape:
+        - Input: :math:`(N, C, L)`, :math:`(N, C, H, W)` or :math:`(N, C, D, H, W)`
+        - Output: same shape as the input
+    """
+
+    def __init__(
+        self,
+        num_features: int,
+        eps: float = 1e-5,
+        affine: bool = False,
+    ):
+        super().__init__()
+        if affine:
+            self.weight = mx.ones((num_features,))
+            self.bias = mx.zeros((num_features,))
+        self.num_features = num_features
+        self.eps = eps
+        self.affine = affine
+        self.param_shapes = {
+            3: (1, num_features, 1),  # input shape: (B, C, L)
+            4: (1, num_features, 1, 1),  # input shape: (B, C, H, W)
+            5: (1, num_features, 1, 1, 1),  # input shape: (B, C, D, H, W)
+        }
+        self.reduction_axes = {
+            3: [2],  # input shape: (B, C, L)
+            4: [2, 3],  # input shape: (B, C, H, W)
+            5: [2, 3, 4],  # input shape: (B, C, D, H, W)
+        }
+
+    def extra_repr(self):
+        return "{num_features}, eps={eps}, affine={affine}".format(**self.__dict__)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        if x.ndim not in self.reduction_axes:
+            raise ValueError(
+                f"InstanceNorm expects 3, 4 or 5 input dimensions but got {x.ndim}"
+            )
+        # Compute per-instance statistics over the spatial axes
+        mean = mx.mean(x, axis=self.reduction_axes[x.ndim], keepdims=True)
+        var = mx.var(x, axis=self.reduction_axes[x.ndim], keepdims=True)
+        # Normalize
+        normalized = (x - mean) * mx.rsqrt(var + self.eps)
+        # Scale and shift if necessary; the per-feature parameters are
+        # reshaped locally so they broadcast over the spatial axes
+        if self.affine:
+            weight = mx.reshape(self.weight, self.param_shapes[x.ndim])
+            bias = mx.reshape(self.bias, self.param_shapes[x.ndim])
+            return weight * normalized + bias
+        return normalized
+
+
 class LayerNorm(Module):
     r"""Applies layer normalization [1] on the inputs.
 
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index f5597474d..faeb63ec8 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -226,6 +226,197 @@ class TestNN(mlx_tests.MLXTestCase):
         self.assertTrue(np.allclose(means, 3 * np.ones_like(means), atol=1e-6))
         self.assertTrue(np.allclose(var, 4 * np.ones_like(var), atol=1e-6))
 
+    def test_instance_norm(self):
+        # Test InstanceNorm1d
+        x = mx.array(
+            [
+                [
+                    [-0.0119524, -0.500331, 1.12958, 1.39955],
+                    [1.1263, 0.517899, -0.21413, 0.891329],
+                    [2.02223, -1.21143, -2.48738, 1.63289],
+                ],
+                [
+                    [0.241417, -1.42512, 2.739, -1.23175],
+                    [-0.619157, 0.970817, -1.2506, 0.32756],
+                    [-0.77484, -1.31352, 1.56844, 1.13969],
+                ],
+            ]
+        )
+        inorm = nn.InstanceNorm(num_features=3)
+        y = inorm(x)
+        expected_y = [
+            [
+                [-0.657082, -1.27879, 0.796097, 1.13978],
+                [1.07593, -0.123075, -1.56572, 0.61286],
+                [1.0712, -0.632503, -1.30476, 0.866066],
+            ],
+            [
+                [0.0964433, -0.904773, 1.59693, -0.788599],
+                [-0.557908, 1.30444, -1.29751, 0.550987],
+                [-0.759886, -1.20013, 1.15521, 0.804804],
+            ],
+        ]
+        self.assertTrue(x.shape == y.shape)
+        self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
+        # Test InstanceNorm2d
+        x = mx.array(
+            [
+                [
+                    [
+                        [-0.458824, -0.447996, 0.0486988],
+                        [1.13049, 0.301795, -2.23876],
+                        [0.0986325, -1.25257, -0.329399],
+                    ],
+                    [
+                        [0.483254, -0.176577, -0.0611224],
+                        [0.345315, 0.99207, -0.758631],
+                        [-1.82973, 0.154442, -0.319107],
+                    ],
+                    [
+                        [-0.58611, -0.622545, 1.8845],
+                        [-0.926389, -0.184927, -1.12639],
+                        [-0.241765, -0.556204, 0.830584],
+                    ],
+                ],
+                [
+                    [
+                        [1.04407, 0.0800776, 0.782321],
+                        [0.671423, -0.110299, 0.159905],
+                        [0.810252, 0.182597, -0.0621687],
+                    ],
+                    [
+                        [0.073752, 1.2513, -0.444367],
+                        [-1.21689, -1.42248, 0.516452],
+                        [1.50456, 0.0576239, 0.184253],
+                    ],
+                    [
+                        [0.407081, 1.20627, 0.563132],
+                        [-1.88979, 1.17838, -0.539121],
+                        [1.08659, 0.973883, 0.784216],
+                    ],
+                ],
+            ]
+        )
+        inorm = nn.InstanceNorm(num_features=3)
+        y = inorm(x)
+        expected_y = [
+            [
+                [
+                    [-0.120422, -0.108465, 0.440008],
+                    [1.63457, 0.719488, -2.08591],
+                    [0.495147, -0.996913, 0.0224944],
+                ],
+                [
+                    [0.801504, -0.0608616, 0.0900314],
+                    [0.621224, 1.4665, -0.821576],
+                    [-2.22144, 0.371763, -0.247141],
+                ],
+                [
+                    [-0.463984, -0.504602, 2.29032],
+                    [-0.843336, -0.0167355, -1.0663],
+                    [-0.0800997, -0.430644, 1.11538],
+                ],
+            ],
+            [
+                [
+                    [1.59749, -0.776381, 0.95293],
+                    [0.679838, -1.24519, -0.579803],
+                    [1.02171, -0.523923, -1.12667],
+                ],
+                [
+                    [0.0190289, 1.28291, -0.537076],
+                    [-1.36624, -1.5869, 0.494185],
+                    [1.55474, 0.00171834, 0.137631],
+                ],
+                [
+                    [-0.012331, 0.817234, 0.149652],
+                    [-2.39651, 0.78829, -0.994498],
+                    [0.693007, 0.576016, 0.37914],
+                ],
+            ],
+        ]
+        self.assertTrue(x.shape == y.shape)
+        self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
+        # Test InstanceNorm3d
+        x = mx.array(
+            [
+                [
+                    [
+                        [[0.777621, -2.1722], [-1.41317, 0.284446]],
+                        [[0.11, -0.837743], [-2.40205, 0.336682]],
+                        [[0.789185, -1.42998], [-0.459489, 0.0298199]],
+                    ],
+                    [
+                        [[0.528145, 0.128192], [0.476288, -0.649858]],
+                        [[-0.12431, 1.93502], [-1.25873, -0.261986]],
+                        [[-1.63747, -1.73247], [-2.15559, 0.10275]],
+                    ],
+                    [
+                        [[-1.56133, 0.153862], [-1.20411, 0.152112]],
+                        [[1.18768, 0.00236324], [-2.04243, 1.54289]],
+                        [[0.67917, -0.402572], [-0.249959, -0.821897]],
+                    ],
+                ],
+                [
+                    [
+                        [[-2.12354, 0.317797], [-0.146628, 0.0329215]],
+                        [[-1.55784, 2.41031], [0.226341, 0.265387]],
+                        [[0.990317, 0.475161], [-1.37804, -0.501041]],
+                    ],
+                    [
+                        [[0.643973, -0.682916], [-0.987925, 1.54086]],
+                        [[0.71179, -0.290786], [0.057712, -0.742304]],
+                        [[-0.399875, -1.10479], [1.40097, 0.0723374]],
+                    ],
+                    [
+                        [[0.72391, 0.016364], [0.573199, 0.213092]],
+                        [[-0.0678402, 0.00449439], [-1.58342, 1.28133]],
+                        [[-0.357647, -1.07389], [0.141618, -0.386141]],
+                    ],
+                ],
+            ]
+        )
+        inorm = nn.InstanceNorm(num_features=3)
+        y = inorm(x)
+        expected_y = [
+            [
+                [
+                    [[1.23593, -1.54739], [-0.831204, 0.770588]],
+                    [[0.605988, -0.288258], [-1.76427, 0.819875]],
+                    [[1.24684, -0.847068], [0.0686449, 0.530334]],
+                ],
+                [
+                    [[0.821849, 0.462867], [0.775304, -0.23548]],
+                    [[0.236231, 2.0846], [-0.78198, 0.112659]],
+                    [[-1.12192, -1.20719], [-1.58697, 0.440032]],
+                ],
+                [
+                    [[-1.30944, 0.357126], [-0.962338, 0.355425]],
+                    [[1.36163, 0.209922], [-1.77689, 1.70677]],
+                    [[0.867539, -0.183531], [-0.0352458, -0.590967]],
+                ],
+            ],
+            [
+                [
+                    [[-1.75315, 0.343736], [-0.0551618, 0.0990544]],
+                    [[-1.26726, 2.14101], [0.265184, 0.298721]],
+                    [[0.921369, 0.478897], [-1.11283, -0.35957]],
+                ],
+                [
+                    [[0.733967, -0.822472], [-1.18025, 1.78602]],
+                    [[0.813517, -0.362504], [0.0462839, -0.892134]],
+                    [[-0.490465, -1.31732], [1.62192, 0.0634394]],
+                ],
+                [
+                    [[1.04349, 0.080661], [0.838402, 0.348368]],
+                    [[-0.033924, 0.0645089], [-2.09632, 1.80203]],
+                    [[-0.428293, -1.40296], [0.251107, -0.467067]],
+                ],
+            ],
+        ]
+        self.assertTrue(x.shape == y.shape)
+        self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
+
     def test_conv1d(self):
         N = 5
         L = 12
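
For reference, a minimal usage sketch of the new layer (not part of the patch; the
shapes below are illustrative). It checks the defining property of instance
normalization: statistics are computed per sample and per channel over the spatial
axes only, so every (sample, channel) slice of the output has roughly zero mean and
unit variance:

    import mlx.core as mx
    import mlx.nn as nn

    # Illustrative input: 2 samples, 3 channels, a 4x4 spatial grid (N, C, H, W).
    x = mx.random.normal((2, 3, 4, 4))

    inorm = nn.InstanceNorm(num_features=3)  # affine=False by default
    y = inorm(x)

    # Reduce over the spatial axes only; each (sample, channel) slice of y
    # was normalized independently.
    print(mx.mean(y, axis=[2, 3]))  # entries close to 0
    print(mx.var(y, axis=[2, 3]))   # entries close to 1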
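
The expected outputs in test_instance_norm are hard-coded, and the patch does not
say how they were generated. One plausible way to cross-check them (an assumption,
not the author's stated method) is PyTorch's channels-first InstanceNorm2d with
matching settings:

    import numpy as np
    import torch
    import mlx.core as mx
    import mlx.nn as nn

    x = np.random.randn(2, 3, 4, 4).astype(np.float32)

    # Reference values from PyTorch (same (N, C, H, W) layout, eps, affine=False).
    ref = torch.nn.InstanceNorm2d(3, eps=1e-5, affine=False)
    y_ref = ref(torch.from_numpy(x)).numpy()

    # The layer added by this patch.
    y_mlx = nn.InstanceNorm(num_features=3)(mx.array(x))

    print(np.allclose(np.array(y_mlx), y_ref, atol=1e-5))  # expect True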