diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md
index 15648e6e0..519e041c5 100644
--- a/ACKNOWLEDGMENTS.md
+++ b/ACKNOWLEDGMENTS.md
@@ -11,6 +11,7 @@ MLX was developed with contributions from the following individuals:
 - Juarez Bochi: Fixed bug in cross attention.
 - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
 - Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot` and safetensor support
+- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented the `norm` method and the `InstanceNorm` layer.
 
 # Third-Party Software
 
diff --git a/docs/src/python/nn/layers.rst b/docs/src/python/nn/layers.rst
index 7ead319fd..4b2107446 100644
--- a/docs/src/python/nn/layers.rst
+++ b/docs/src/python/nn/layers.rst
@@ -26,6 +26,7 @@ Layers
    LayerNorm
    RMSNorm
    GroupNorm
+   InstanceNorm
    Dropout
    Dropout2d
    Dropout3d
diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py
index 29787a3cc..80d500c9f 100644
--- a/python/mlx/nn/layers/__init__.py
+++ b/python/mlx/nn/layers/__init__.py
@@ -46,7 +46,13 @@ from mlx.nn.layers.convolution import Conv1d, Conv2d
 from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d
 from mlx.nn.layers.embedding import Embedding
 from mlx.nn.layers.linear import Bilinear, Identity, Linear
-from mlx.nn.layers.normalization import BatchNorm, GroupNorm, LayerNorm, RMSNorm
+from mlx.nn.layers.normalization import (
+    BatchNorm,
+    GroupNorm,
+    InstanceNorm,
+    LayerNorm,
+    RMSNorm,
+)
 from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
 from mlx.nn.layers.quantized import QuantizedLinear
 from mlx.nn.layers.transformer import (
diff --git a/python/mlx/nn/layers/normalization.py b/python/mlx/nn/layers/normalization.py
index 9c77667e7..b2b60ccba 100644
--- a/python/mlx/nn/layers/normalization.py
+++ b/python/mlx/nn/layers/normalization.py
@@ -6,6 +6,66 @@ import mlx.core as mx
 from mlx.nn.layers.base import Module
 
 
+class InstanceNorm(Module):
+    r"""Applies instance normalization [1] on the inputs.
+
+    Computes
+
+    .. math::
+
+        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta,
+
+    where :math:`\gamma` and :math:`\beta` are learned per feature dimension
+    parameters initialized at 1 and 0 respectively. Both are of size :attr:`dims`
+    if :attr:`affine` is ``True``.
+
+    Args:
+        dims (int): The number of features of the input.
+        eps (float): A value added to the denominator for numerical stability. Default: ``1e-5``.
+        affine (bool): If ``True``, learn an affine transform to apply after the
+            normalization. Default: ``False``.
+
+    Shape:
+        - Input: :math:`(..., C)` where :math:`C` is equal to :attr:`dims`.
+        - Output: Same shape as the input.
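+
+    .. note::
+       The normalization statistics (the mean and variance above) are computed
+       per sample over all axes except the first (batch) and last (feature) axis.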
+
+    Examples:
+        >>> import mlx.core as mx
+        >>> import mlx.nn as nn
+        >>> x = mx.random.normal((8, 4, 4, 16))
+        >>> inorm = nn.InstanceNorm(dims=16)
+        >>> output = inorm(x)
+
+    References:
+        [1]: https://arxiv.org/abs/1607.08022
+    """
+
+    def __init__(
+        self,
+        dims: int,
+        eps: float = 1e-5,
+        affine: bool = False,
+    ):
+        super().__init__()
+        if affine:
+            self.weight = mx.ones((dims,))
+            self.bias = mx.zeros((dims,))
+        self.dims = dims
+        self.eps = eps
+
+    def _extra_repr(self):
+        return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"
+
+    def __call__(self, x: mx.array) -> mx.array:
+        reduction_axes = tuple(range(1, x.ndim - 1))
+        # Compute stats
+        mean = mx.mean(x, axis=reduction_axes, keepdims=True)
+        var = mx.var(x, axis=reduction_axes, keepdims=True)
+        # Normalize
+        x = (x - mean) * mx.rsqrt(var + self.eps)
+        # Scale and shift if necessary
+        return (self.weight * x + self.bias) if "weight" in self else x
+
+
 class LayerNorm(Module):
     r"""Applies layer normalization [1] on the inputs.
 
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index 8529f33a6..28e72a7e7 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -172,6 +172,224 @@ class TestLayers(mlx_tests.MLXTestCase):
         self.assertTrue(np.allclose(means, 3 * np.ones_like(means), atol=1e-6))
         self.assertTrue(np.allclose(var, 4 * np.ones_like(var), atol=1e-6))
 
+    def test_instance_norm(self):
+        # Test InstanceNorm1d
+        x = mx.array(
+            [
+                [
+                    [-0.0119524, 1.1263, 2.02223],
+                    [-0.500331, 0.517899, -1.21143],
+                    [1.12958, -0.21413, -2.48738],
+                    [1.39955, 0.891329, 1.63289],
+                ],
+                [
+                    [0.241417, -0.619157, -0.77484],
+                    [-1.42512, 0.970817, -1.31352],
+                    [2.739, -1.2506, 1.56844],
+                    [-1.23175, 0.32756, 1.13969],
+                ],
+            ]
+        )
+        inorm = nn.InstanceNorm(dims=3)
+        y = inorm(x)
+        expected_y = [
+            [
+                [-0.657082, 1.07593, 1.0712],
+                [-1.27879, -0.123074, -0.632505],
+                [0.796101, -1.56572, -1.30476],
+                [1.13978, 0.612862, 0.866067],
+            ],
+            [
+                [0.0964426, -0.557906, -0.759885],
+                [-0.904772, 1.30444, -1.20013],
+                [1.59693, -1.29752, 1.15521],
+                [-0.7886, 0.550987, 0.804807],
+            ],
+        ]
+        self.assertTrue(x.shape == y.shape)
+        self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
+        # Test InstanceNorm2d
+        x = mx.array(
+            [
+                [
+                    [
+                        [-0.458824, 0.483254, -0.58611],
+                        [-0.447996, -0.176577, -0.622545],
+                        [0.0486988, -0.0611224, 1.8845],
+                    ],
+                    [
+                        [1.13049, 0.345315, -0.926389],
+                        [0.301795, 0.99207, -0.184927],
+                        [-2.23876, -0.758631, -1.12639],
+                    ],
+                    [
+                        [0.0986325, -1.82973, -0.241765],
+                        [-1.25257, 0.154442, -0.556204],
+                        [-0.329399, -0.319107, 0.830584],
+                    ],
+                ],
+                [
+                    [
+                        [1.04407, 0.073752, 0.407081],
+                        [0.0800776, 1.2513, 1.20627],
+                        [0.782321, -0.444367, 0.563132],
+                    ],
+                    [
+                        [0.671423, -1.21689, -1.88979],
+                        [-0.110299, -1.42248, 1.17838],
+                        [0.159905, 0.516452, -0.539121],
+                    ],
+                    [
+                        [0.810252, 1.50456, 1.08659],
+                        [0.182597, 0.0576239, 0.973883],
+                        [-0.0621687, 0.184253, 0.784216],
+                    ],
+                ],
+            ]
+        )
+        inorm = nn.InstanceNorm(dims=3)
+        y = inorm(x)
+        expected_y = [
+            [
+                [
+                    [-0.120422, 0.801503, -0.463983],
+                    [-0.108465, -0.0608611, -0.504602],
+                    [0.440008, 0.090032, 2.29032],
+                ],
+                [
+                    [1.63457, 0.621224, -0.843335],
+                    [0.719488, 1.4665, -0.0167344],
+                    [-2.08591, -0.821575, -1.0663],
+                ],
+                [
+                    [0.495147, -2.22145, -0.0800989],
+                    [-0.996913, 0.371763, -0.430643],
+                    [0.022495, -0.24714, 1.11538],
+                ],
+            ],
+            [
+                [
+                    [1.5975, 0.0190292, -0.0123306],
+                    [-0.776381, 1.28291, 0.817237],
+                    [0.952927, -0.537076, 0.149652],
+                ],
+                [
+                    [0.679836, -1.36624, -2.39651],
+                    [-1.24519, -1.5869, 0.788287],
+                    [-0.579802, 0.494186, -0.994499],
+                ],
+                [
+                    [1.02171, 1.55474, 0.693008],
+                    [-0.523922, 0.00171862, 0.576016],
+                    [-1.12667, 0.137632, 0.37914],
+                ],
+            ],
+        ]
+        self.assertTrue(x.shape == y.shape)
+        self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
+        # Test InstanceNorm3d
+        x = mx.array(
+            [
+                [
+                    [
+                        [[0.777621, 0.528145, -1.56133], [-2.1722, 0.128192, 0.153862]],
+                        [
+                            [-1.41317, 0.476288, -1.20411],
+                            [0.284446, -0.649858, 0.152112],
+                        ],
+                    ],
+                    [
+                        [[0.11, -0.12431, 1.18768], [-0.837743, 1.93502, 0.00236324]],
+                        [
+                            [-2.40205, -1.25873, -2.04243],
+                            [0.336682, -0.261986, 1.54289],
+                        ],
+                    ],
+                    [
+                        [
+                            [0.789185, -1.63747, 0.67917],
+                            [-1.42998, -1.73247, -0.402572],
+                        ],
+                        [
+                            [-0.459489, -2.15559, -0.249959],
+                            [0.0298199, 0.10275, -0.821897],
+                        ],
+                    ],
+                ],
+                [
+                    [
+                        [
+                            [-2.12354, 0.643973, 0.72391],
+                            [0.317797, -0.682916, 0.016364],
+                        ],
+                        [
+                            [-0.146628, -0.987925, 0.573199],
+                            [0.0329215, 1.54086, 0.213092],
+                        ],
+                    ],
+                    [
+                        [
+                            [-1.55784, 0.71179, -0.0678402],
+                            [2.41031, -0.290786, 0.00449439],
+                        ],
+                        [
+                            [0.226341, 0.057712, -1.58342],
+                            [0.265387, -0.742304, 1.28133],
+                        ],
+                    ],
+                    [
+                        [
+                            [0.990317, -0.399875, -0.357647],
+                            [0.475161, -1.10479, -1.07389],
+                        ],
+                        [
+                            [-1.37804, 1.40097, 0.141618],
+                            [-0.501041, 0.0723374, -0.386141],
+                        ],
+                    ],
+                ],
+            ]
+        )
+        inorm = nn.InstanceNorm(dims=3)
+        y = inorm(x)
+        expected_y = [
+            [
+                [
+                    [[1.23593, 0.821849, -1.30944], [-1.54739, 0.462867, 0.357126]],
+                    [[-0.831204, 0.775304, -0.962338], [0.770588, -0.23548, 0.355425]],
+                ],
+                [
+                    [[0.605988, 0.236231, 1.36163], [-0.288258, 2.0846, 0.209922]],
+                    [[-1.76427, -0.78198, -1.77689], [0.819875, 0.112659, 1.70677]],
+                ],
+                [
+                    [[1.24684, -1.12192, 0.867539], [-0.847068, -1.20719, -0.183531]],
+                    [
+                        [0.0686449, -1.58697, -0.0352458],
+                        [0.530334, 0.440032, -0.590967],
+                    ],
+                ],
+            ],
+            [
+                [
+                    [[-1.75315, 0.733967, 1.04349], [0.343736, -0.822472, 0.080661]],
+                    [[-0.0551618, -1.18025, 0.838402], [0.0990544, 1.78602, 0.348368]],
+                ],
+                [
+                    [[-1.26726, 0.813517, -0.033924], [2.14101, -0.362504, 0.0645089]],
+                    [[0.265184, 0.0462839, -2.09632], [0.298721, -0.892134, 1.80203]],
+                ],
+                [
+                    [[0.921369, -0.490465, -0.428293], [0.478897, -1.31732, -1.40296]],
+                    [[-1.11283, 1.62192, 0.251107], [-0.35957, 0.0634394, -0.467067]],
+                ],
+            ],
+        ]
+        self.assertTrue(x.shape == y.shape)
+        self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
+        # Test repr
+        self.assertTrue(str(inorm) == "InstanceNorm(3, eps=1e-05, affine=False)")
+
     def test_batch_norm(self):
         mx.random.seed(42)
         x = mx.random.normal((5, 4), dtype=mx.float32)
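For reference, a minimal usage sketch of the layer added above; the input shape and the `affine=True` option are illustrative choices for this example, not part of the diff. With the default parameter initialization the affine transform is the identity, so each (sample, channel) slice of the output is simply normalized over the spatial axes.

import mlx.core as mx
import mlx.nn as nn

# A batch of 8 feature maps of size 4x4 with 16 channels (channels-last layout).
x = mx.random.normal((8, 4, 4, 16))

# affine=True adds a learnable per-channel scale (weight) and shift (bias),
# initialized at 1 and 0, so at initialization the output matches affine=False.
inorm = nn.InstanceNorm(dims=16, affine=True)
y = inorm(x)

print(y.shape)  # (8, 4, 4, 16): same shape as the input

# Statistics are computed per sample over the spatial axes (axes 1 and 2 here),
# so each (sample, channel) slice has mean ~0 and variance ~1 at initialization.
print(mx.mean(y, axis=(1, 2)))  # (8, 16) array with entries close to 0
print(mx.var(y, axis=(1, 2)))   # (8, 16) array with entries close to 1

The tests above check the same normalization behavior numerically against precomputed expected outputs for 1d, 2d, and 3d inputs.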