implemented InstanceNorm (#244)

* implemented InstanceNorm

* implemented vector_norm in C++

added linalg to mlx

* implemented vector_norm python binding

* renamed vector_norm to norm, implemented norm without provided ord

* completed the implementation of the norm

* added tests

* removed unused import in linalg.cpp

* updated python bindings

* added some tests for python bindings

* handled inf, -inf as numpy does; added more extensive tests of compatibility with numpy (a usage sketch of the resulting API follows the changed-files summary below)

* added better docs and examples

* refactored mlx.linalg.norm bindings

* reused existing util for implementation of linalg.norm

* more tests

* fixed a bug when neither ord nor axis is provided

* removed unused imports

* some style and API consistency updates to linalg norm

* remove unused includes

* fix python tests

* fixed a bug with frobenius norm of a complex-valued matrix

* complex for vector too

* addressed PR review comments

* fixed import order in __init__

* expected values in InstanceNorm tests are simple lists

* minor return expression style change

* added InstanceNorm to docs

* docstring nits

* added myself to individual contributors

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Gabrijel Boduljak 2024-01-03 21:21:15 +01:00 committed by GitHub
parent dff4a3833f
commit c7edafb729
5 changed files with 287 additions and 1 deletion
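
For context, here is a brief usage sketch of the `mlx.core.linalg.norm` API the commits above describe. The commit log states it follows numpy's handling of `ord`, including `inf`/`-inf`, so the calls and expected values below assume `numpy.linalg.norm` semantics; treat this as illustrative rather than a definitive reference for the binding's exact signature.

import mlx.core as mx

v = mx.array([1.0, -3.0, 2.0])
m = mx.array([[1.0, 2.0], [3.0, 4.0]])

# Default ord: 2-norm for vectors, Frobenius norm for matrices
print(mx.linalg.norm(v))  # sqrt(1 + 9 + 4) ~= 3.742
print(mx.linalg.norm(m))  # sqrt(1 + 4 + 9 + 16) ~= 5.477

# inf / -inf behave as in numpy: max / min absolute value for vectors
print(mx.linalg.norm(v, float("inf")))   # 3.0
print(mx.linalg.norm(v, float("-inf")))  # 1.0

# axis restricts the reduction, as in numpy
print(mx.linalg.norm(m, axis=1))  # per-row 2-norms: [sqrt(5), 5]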


@@ -11,6 +11,7 @@ MLX was developed with contributions from the following individuals:
- Juarez Bochi: Fixed bug in cross attention.
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot` and safetensor support
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer.
# Third-Party Software


@@ -26,6 +26,7 @@ Layers
LayerNorm
RMSNorm
GroupNorm
InstanceNorm
Dropout
Dropout2d
Dropout3d


@@ -46,7 +46,13 @@ from mlx.nn.layers.convolution import Conv1d, Conv2d
from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d
from mlx.nn.layers.embedding import Embedding
from mlx.nn.layers.linear import Bilinear, Identity, Linear
from mlx.nn.layers.normalization import BatchNorm, GroupNorm, LayerNorm, RMSNorm
from mlx.nn.layers.normalization import (
    BatchNorm,
    GroupNorm,
    InstanceNorm,
    LayerNorm,
    RMSNorm,
)
from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
from mlx.nn.layers.quantized import QuantizedLinear
from mlx.nn.layers.transformer import (


@@ -6,6 +6,66 @@ import mlx.core as mx
from mlx.nn.layers.base import Module


class InstanceNorm(Module):
    r"""Applies instance normalization [1] on the inputs.

    Computes

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta,

    where :math:`\gamma` and :math:`\beta` are learned per feature dimension
    parameters initialized at 1 and 0 respectively. Both are of size :attr:`dims`,
    if :attr:`affine` is ``True``.

    Args:
        dims (int): The number of features of the input.
        eps (float): A value added to the denominator for numerical stability. Default: ``1e-5``.
        affine (bool): If ``True``, learn an affine transform to apply after the normalization. Default: ``False``.

    Shape:
        - Input: :math:`(..., C)` where :math:`C` is equal to :attr:`dims`.
        - Output: Same shape as the input.

    Examples:
        >>> import mlx.core as mx
        >>> import mlx.nn as nn
        >>> x = mx.random.normal((8, 4, 4, 16))
        >>> inorm = nn.InstanceNorm(dims=16)
        >>> output = inorm(x)

    References:
        [1]: https://arxiv.org/abs/1607.08022
    """

    def __init__(
        self,
        dims: int,
        eps: float = 1e-5,
        affine: bool = False,
    ):
        super().__init__()
        if affine:
            self.weight = mx.ones((dims,))
            self.bias = mx.zeros((dims,))
        self.dims = dims
        self.eps = eps

    def _extra_repr(self):
        return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"

    def __call__(self, x: mx.array) -> mx.array:
        # Reduce over every axis except the first (batch) and last (feature)
        reduction_axes = tuple(range(1, x.ndim - 1))
        # Compute stats
        mean = mx.mean(x, axis=reduction_axes, keepdims=True)
        var = mx.var(x, axis=reduction_axes, keepdims=True)
        # Normalize
        x = (x - mean) * mx.rsqrt(var + self.eps)
        # Scale and shift only if the affine parameters were created
        return (self.weight * x + self.bias) if "weight" in self else x


class LayerNorm(Module):
    r"""Applies layer normalization [1] on the inputs.

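Beyond the docstring example above, a minimal sketch of the affine path may help. The parameter names `weight` and `bias` come from the implementation shown in this diff; the statistical claims are simply the definition of instance normalization at initialization.

import mlx.core as mx
import mlx.nn as nn

x = mx.random.normal((2, 8, 8, 4))            # channels-last input: (N, H, W, C)
inorm = nn.InstanceNorm(dims=4, affine=True)  # creates weight (ones) and bias (zeros)
y = inorm(x)

# Statistics are reduced over the spatial axes (1, 2) per sample and channel,
# so at initialization y is approximately zero-mean with unit variance there.
print(mx.mean(y, axis=(1, 2)).shape)  # (2, 4)
print(mx.var(y, axis=(1, 2)).shape)   # (2, 4)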

@@ -172,6 +172,224 @@ class TestLayers(mlx_tests.MLXTestCase):
        self.assertTrue(np.allclose(means, 3 * np.ones_like(means), atol=1e-6))
        self.assertTrue(np.allclose(var, 4 * np.ones_like(var), atol=1e-6))

    def test_instance_norm(self):
        # Test InstanceNorm1d
        x = mx.array(
            [
                [
                    [-0.0119524, 1.1263, 2.02223],
                    [-0.500331, 0.517899, -1.21143],
                    [1.12958, -0.21413, -2.48738],
                    [1.39955, 0.891329, 1.63289],
                ],
                [
                    [0.241417, -0.619157, -0.77484],
                    [-1.42512, 0.970817, -1.31352],
                    [2.739, -1.2506, 1.56844],
                    [-1.23175, 0.32756, 1.13969],
                ],
            ]
        )
        inorm = nn.InstanceNorm(dims=3)
        y = inorm(x)
        expected_y = [
            [
                [-0.657082, 1.07593, 1.0712],
                [-1.27879, -0.123074, -0.632505],
                [0.796101, -1.56572, -1.30476],
                [1.13978, 0.612862, 0.866067],
            ],
            [
                [0.0964426, -0.557906, -0.759885],
                [-0.904772, 1.30444, -1.20013],
                [1.59693, -1.29752, 1.15521],
                [-0.7886, 0.550987, 0.804807],
            ],
        ]
        self.assertTrue(x.shape == y.shape)
        self.assertTrue(np.allclose(y, expected_y, atol=1e-5))

        # Test InstanceNorm2d
        x = mx.array(
            [
                [
                    [
                        [-0.458824, 0.483254, -0.58611],
                        [-0.447996, -0.176577, -0.622545],
                        [0.0486988, -0.0611224, 1.8845],
                    ],
                    [
                        [1.13049, 0.345315, -0.926389],
                        [0.301795, 0.99207, -0.184927],
                        [-2.23876, -0.758631, -1.12639],
                    ],
                    [
                        [0.0986325, -1.82973, -0.241765],
                        [-1.25257, 0.154442, -0.556204],
                        [-0.329399, -0.319107, 0.830584],
                    ],
                ],
                [
                    [
                        [1.04407, 0.073752, 0.407081],
                        [0.0800776, 1.2513, 1.20627],
                        [0.782321, -0.444367, 0.563132],
                    ],
                    [
                        [0.671423, -1.21689, -1.88979],
                        [-0.110299, -1.42248, 1.17838],
                        [0.159905, 0.516452, -0.539121],
                    ],
                    [
                        [0.810252, 1.50456, 1.08659],
                        [0.182597, 0.0576239, 0.973883],
                        [-0.0621687, 0.184253, 0.784216],
                    ],
                ],
            ]
        )
        inorm = nn.InstanceNorm(dims=3)
        y = inorm(x)
        expected_y = [
            [
                [
                    [-0.120422, 0.801503, -0.463983],
                    [-0.108465, -0.0608611, -0.504602],
                    [0.440008, 0.090032, 2.29032],
                ],
                [
                    [1.63457, 0.621224, -0.843335],
                    [0.719488, 1.4665, -0.0167344],
                    [-2.08591, -0.821575, -1.0663],
                ],
                [
                    [0.495147, -2.22145, -0.0800989],
                    [-0.996913, 0.371763, -0.430643],
                    [0.022495, -0.24714, 1.11538],
                ],
            ],
            [
                [
                    [1.5975, 0.0190292, -0.0123306],
                    [-0.776381, 1.28291, 0.817237],
                    [0.952927, -0.537076, 0.149652],
                ],
                [
                    [0.679836, -1.36624, -2.39651],
                    [-1.24519, -1.5869, 0.788287],
                    [-0.579802, 0.494186, -0.994499],
                ],
                [
                    [1.02171, 1.55474, 0.693008],
                    [-0.523922, 0.00171862, 0.576016],
                    [-1.12667, 0.137632, 0.37914],
                ],
            ],
        ]
        self.assertTrue(x.shape == y.shape)
        self.assertTrue(np.allclose(y, expected_y, atol=1e-5))

        # Test InstanceNorm3d
        x = mx.array(
            [
                [
                    [
                        [[0.777621, 0.528145, -1.56133], [-2.1722, 0.128192, 0.153862]],
                        [[-1.41317, 0.476288, -1.20411], [0.284446, -0.649858, 0.152112]],
                    ],
                    [
                        [[0.11, -0.12431, 1.18768], [-0.837743, 1.93502, 0.00236324]],
                        [[-2.40205, -1.25873, -2.04243], [0.336682, -0.261986, 1.54289]],
                    ],
                    [
                        [[0.789185, -1.63747, 0.67917], [-1.42998, -1.73247, -0.402572]],
                        [[-0.459489, -2.15559, -0.249959], [0.0298199, 0.10275, -0.821897]],
                    ],
                ],
                [
                    [
                        [[-2.12354, 0.643973, 0.72391], [0.317797, -0.682916, 0.016364]],
                        [[-0.146628, -0.987925, 0.573199], [0.0329215, 1.54086, 0.213092]],
                    ],
                    [
                        [[-1.55784, 0.71179, -0.0678402], [2.41031, -0.290786, 0.00449439]],
                        [[0.226341, 0.057712, -1.58342], [0.265387, -0.742304, 1.28133]],
                    ],
                    [
                        [[0.990317, -0.399875, -0.357647], [0.475161, -1.10479, -1.07389]],
                        [[-1.37804, 1.40097, 0.141618], [-0.501041, 0.0723374, -0.386141]],
                    ],
                ],
            ]
        )
        inorm = nn.InstanceNorm(dims=3)
        y = inorm(x)
        expected_y = [
            [
                [
                    [[1.23593, 0.821849, -1.30944], [-1.54739, 0.462867, 0.357126]],
                    [[-0.831204, 0.775304, -0.962338], [0.770588, -0.23548, 0.355425]],
                ],
                [
                    [[0.605988, 0.236231, 1.36163], [-0.288258, 2.0846, 0.209922]],
                    [[-1.76427, -0.78198, -1.77689], [0.819875, 0.112659, 1.70677]],
                ],
                [
                    [[1.24684, -1.12192, 0.867539], [-0.847068, -1.20719, -0.183531]],
                    [[0.0686449, -1.58697, -0.0352458], [0.530334, 0.440032, -0.590967]],
                ],
            ],
            [
                [
                    [[-1.75315, 0.733967, 1.04349], [0.343736, -0.822472, 0.080661]],
                    [[-0.0551618, -1.18025, 0.838402], [0.0990544, 1.78602, 0.348368]],
                ],
                [
                    [[-1.26726, 0.813517, -0.033924], [2.14101, -0.362504, 0.0645089]],
                    [[0.265184, 0.0462839, -2.09632], [0.298721, -0.892134, 1.80203]],
                ],
                [
                    [[0.921369, -0.490465, -0.428293], [0.478897, -1.31732, -1.40296]],
                    [[-1.11283, 1.62192, 0.251107], [-0.35957, 0.0634394, -0.467067]],
                ],
            ],
        ]
        self.assertTrue(x.shape == y.shape)
        self.assertTrue(np.allclose(y, expected_y, atol=1e-5))

        # Test repr
        self.assertTrue(str(inorm) == "InstanceNorm(3, eps=1e-05, affine=False)")

    def test_batch_norm(self):
        mx.random.seed(42)
        x = mx.random.normal((5, 4), dtype=mx.float32)