implemented InstanceNorm (#244)

* implemented instancenorm

* implemented vector_norm in cpp

added linalg to mlx

* implemented vector_norm python binding

* renamed vector_norm to norm, implemented norm without provided ord

* completed the implementation of the norm

* added tests

* removed unused import in linalg.cpp

* updated python bindings

* added some tests for python bindings

* handling inf, -inf as numpy does, more extensive tests of compatibility with numpy (see the usage sketch after this message)

* added better docs and examples

* refactored mlx.linalg.norm bindings

* reused existing util for implementation of linalg.norm

* more tests

* fixed a bug with no ord and axis provided

* removed unused imports

* some style and API consistency updates to linalg norm

* remove unused includes

* fix python tests

* fixed a bug with frobenius norm of a complex-valued matrix

* complex for vector too

* addressed PR review comments

* fixed import order in __init__

* expected values in instancenorm tests are simple lists

* minor return expression style change

* added InstanceNorm to docs

* doc string nits

* added myself to individual contributors

---------

Co-authored-by: Awni Hannun <awni@apple.com>
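The `norm` described in the bullets above follows NumPy's conventions for `ord`, including `inf` and `-inf`. A minimal usage sketch, assuming a build of mlx that includes the `mlx.core.linalg` module added by this commit (expected outputs in comments, not taken from the diff):

    import mlx.core as mx

    # Vector norms: with no ord, the 2-norm is computed, as in numpy.linalg.norm.
    v = mx.array([3.0, -4.0])
    print(mx.linalg.norm(v))                # 5.0 (2-norm)
    print(mx.linalg.norm(v, 1))             # 7.0 (sum of |v_i|)
    print(mx.linalg.norm(v, float("inf")))  # 4.0 (max of |v_i|, numpy-style inf handling)

    # Matrix norms: with no ord, a 2-D input yields the Frobenius norm.
    m = mx.array([[1.0, 2.0], [3.0, 4.0]])
    print(mx.linalg.norm(m))                # sqrt(30) ~= 5.477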
Gabrijel Boduljak 2024-01-03 21:21:15 +01:00 committed by GitHub
parent dff4a3833f
commit c7edafb729
5 changed files with 287 additions and 1 deletions


@@ -11,6 +11,7 @@ MLX was developed with contributions from the following individuals:
- Juarez Bochi: Fixed bug in cross attention.
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot` and safetensor support
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer.

# Third-Party Software


@@ -26,6 +26,7 @@ Layers
LayerNorm
RMSNorm
GroupNorm
InstanceNorm
Dropout
Dropout2d
Dropout3d


@@ -46,7 +46,13 @@ from mlx.nn.layers.convolution import Conv1d, Conv2d
from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d
from mlx.nn.layers.embedding import Embedding
from mlx.nn.layers.linear import Bilinear, Identity, Linear
from mlx.nn.layers.normalization import (
    BatchNorm,
    GroupNorm,
    InstanceNorm,
    LayerNorm,
    RMSNorm,
)
from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
from mlx.nn.layers.quantized import QuantizedLinear
from mlx.nn.layers.transformer import (


@@ -6,6 +6,66 @@ import mlx.core as mx
from mlx.nn.layers.base import Module


class InstanceNorm(Module):
    r"""Applies instance normalization [1] on the inputs.

    Computes

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta,

    where :math:`\gamma` and :math:`\beta` are learned per feature dimension
    parameters initialized at 1 and 0 respectively. Both are of size :attr:`dims`,
    if :attr:`affine` is ``True``.

    Args:
        dims (int): The number of features of the input.
        eps (float): A value added to the denominator for numerical stability. Default: ``1e-5``.
        affine (bool): If ``True``, learn an affine transform to apply after the normalization. Default: ``False``.

    Shape:
        - Input: :math:`(..., C)` where :math:`C` is equal to :attr:`dims`.
        - Output: Same shape as the input.

    Examples:
        >>> import mlx.core as mx
        >>> import mlx.nn as nn
        >>> x = mx.random.normal((8, 4, 4, 16))
        >>> inorm = nn.InstanceNorm(dims=16)
        >>> output = inorm(x)

    References:
        [1]: https://arxiv.org/abs/1607.08022
    """

    def __init__(
        self,
        dims: int,
        eps: float = 1e-5,
        affine: bool = False,
    ):
        super().__init__()
        if affine:
            self.weight = mx.ones((dims,))
            self.bias = mx.zeros((dims,))
        self.dims = dims
        self.eps = eps

    def _extra_repr(self):
        return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"

    def __call__(self, x: mx.array) -> mx.array:
        reduction_axes = tuple(range(1, x.ndim - 1))
        # Compute stats
        mean = mx.mean(x, axis=reduction_axes, keepdims=True)
        var = mx.var(x, axis=reduction_axes, keepdims=True)
        # Normalize
        x = (x - mean) * mx.rsqrt(var + self.eps)
        # Scale and shift if necessary
        return (self.weight * x + self.bias) if "weight" in self else x


class LayerNorm(Module):
    r"""Applies layer normalization [1] on the inputs.


@@ -172,6 +172,224 @@ class TestLayers(mlx_tests.MLXTestCase):
        self.assertTrue(np.allclose(means, 3 * np.ones_like(means), atol=1e-6))
        self.assertTrue(np.allclose(var, 4 * np.ones_like(var), atol=1e-6))

    def test_instance_norm(self):
        # Test InstanceNorm1d
        x = mx.array(
            [
                [
                    [-0.0119524, 1.1263, 2.02223],
                    [-0.500331, 0.517899, -1.21143],
                    [1.12958, -0.21413, -2.48738],
                    [1.39955, 0.891329, 1.63289],
                ],
                [
                    [0.241417, -0.619157, -0.77484],
                    [-1.42512, 0.970817, -1.31352],
                    [2.739, -1.2506, 1.56844],
                    [-1.23175, 0.32756, 1.13969],
                ],
            ]
        )
        inorm = nn.InstanceNorm(dims=3)
        y = inorm(x)
        expected_y = [
            [
                [-0.657082, 1.07593, 1.0712],
                [-1.27879, -0.123074, -0.632505],
                [0.796101, -1.56572, -1.30476],
                [1.13978, 0.612862, 0.866067],
            ],
            [
                [0.0964426, -0.557906, -0.759885],
                [-0.904772, 1.30444, -1.20013],
                [1.59693, -1.29752, 1.15521],
                [-0.7886, 0.550987, 0.804807],
            ],
        ]
        self.assertTrue(x.shape == y.shape)
        self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
        # Test InstanceNorm2d
        x = mx.array(
            [
                [
                    [
                        [-0.458824, 0.483254, -0.58611],
                        [-0.447996, -0.176577, -0.622545],
                        [0.0486988, -0.0611224, 1.8845],
                    ],
                    [
                        [1.13049, 0.345315, -0.926389],
                        [0.301795, 0.99207, -0.184927],
                        [-2.23876, -0.758631, -1.12639],
                    ],
                    [
                        [0.0986325, -1.82973, -0.241765],
                        [-1.25257, 0.154442, -0.556204],
                        [-0.329399, -0.319107, 0.830584],
                    ],
                ],
                [
                    [
                        [1.04407, 0.073752, 0.407081],
                        [0.0800776, 1.2513, 1.20627],
                        [0.782321, -0.444367, 0.563132],
                    ],
                    [
                        [0.671423, -1.21689, -1.88979],
                        [-0.110299, -1.42248, 1.17838],
                        [0.159905, 0.516452, -0.539121],
                    ],
                    [
                        [0.810252, 1.50456, 1.08659],
                        [0.182597, 0.0576239, 0.973883],
                        [-0.0621687, 0.184253, 0.784216],
                    ],
                ],
            ]
        )
        inorm = nn.InstanceNorm(dims=3)
        y = inorm(x)
        expected_y = [
            [
                [
                    [-0.120422, 0.801503, -0.463983],
                    [-0.108465, -0.0608611, -0.504602],
                    [0.440008, 0.090032, 2.29032],
                ],
                [
                    [1.63457, 0.621224, -0.843335],
                    [0.719488, 1.4665, -0.0167344],
                    [-2.08591, -0.821575, -1.0663],
                ],
                [
                    [0.495147, -2.22145, -0.0800989],
                    [-0.996913, 0.371763, -0.430643],
                    [0.022495, -0.24714, 1.11538],
                ],
            ],
            [
                [
                    [1.5975, 0.0190292, -0.0123306],
                    [-0.776381, 1.28291, 0.817237],
                    [0.952927, -0.537076, 0.149652],
                ],
                [
                    [0.679836, -1.36624, -2.39651],
                    [-1.24519, -1.5869, 0.788287],
                    [-0.579802, 0.494186, -0.994499],
                ],
                [
                    [1.02171, 1.55474, 0.693008],
                    [-0.523922, 0.00171862, 0.576016],
                    [-1.12667, 0.137632, 0.37914],
                ],
            ],
        ]
        self.assertTrue(x.shape == y.shape)
        self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
        # Test InstanceNorm3d
        x = mx.array(
            [
                [
                    [
                        [[0.777621, 0.528145, -1.56133], [-2.1722, 0.128192, 0.153862]],
                        [
                            [-1.41317, 0.476288, -1.20411],
                            [0.284446, -0.649858, 0.152112],
                        ],
                    ],
                    [
                        [[0.11, -0.12431, 1.18768], [-0.837743, 1.93502, 0.00236324]],
                        [
                            [-2.40205, -1.25873, -2.04243],
                            [0.336682, -0.261986, 1.54289],
                        ],
                    ],
                    [
                        [
                            [0.789185, -1.63747, 0.67917],
                            [-1.42998, -1.73247, -0.402572],
                        ],
                        [
                            [-0.459489, -2.15559, -0.249959],
                            [0.0298199, 0.10275, -0.821897],
                        ],
                    ],
                ],
                [
                    [
                        [
                            [-2.12354, 0.643973, 0.72391],
                            [0.317797, -0.682916, 0.016364],
                        ],
                        [
                            [-0.146628, -0.987925, 0.573199],
                            [0.0329215, 1.54086, 0.213092],
                        ],
                    ],
                    [
                        [
                            [-1.55784, 0.71179, -0.0678402],
                            [2.41031, -0.290786, 0.00449439],
                        ],
                        [
                            [0.226341, 0.057712, -1.58342],
                            [0.265387, -0.742304, 1.28133],
                        ],
                    ],
                    [
                        [
                            [0.990317, -0.399875, -0.357647],
                            [0.475161, -1.10479, -1.07389],
                        ],
                        [
                            [-1.37804, 1.40097, 0.141618],
                            [-0.501041, 0.0723374, -0.386141],
                        ],
                    ],
                ],
            ]
        )
        inorm = nn.InstanceNorm(dims=3)
        y = inorm(x)
        expected_y = [
            [
                [
                    [[1.23593, 0.821849, -1.30944], [-1.54739, 0.462867, 0.357126]],
                    [[-0.831204, 0.775304, -0.962338], [0.770588, -0.23548, 0.355425]],
                ],
                [
                    [[0.605988, 0.236231, 1.36163], [-0.288258, 2.0846, 0.209922]],
                    [[-1.76427, -0.78198, -1.77689], [0.819875, 0.112659, 1.70677]],
                ],
                [
                    [[1.24684, -1.12192, 0.867539], [-0.847068, -1.20719, -0.183531]],
                    [
                        [0.0686449, -1.58697, -0.0352458],
                        [0.530334, 0.440032, -0.590967],
                    ],
                ],
            ],
            [
                [
                    [[-1.75315, 0.733967, 1.04349], [0.343736, -0.822472, 0.080661]],
                    [[-0.0551618, -1.18025, 0.838402], [0.0990544, 1.78602, 0.348368]],
                ],
                [
                    [[-1.26726, 0.813517, -0.033924], [2.14101, -0.362504, 0.0645089]],
                    [[0.265184, 0.0462839, -2.09632], [0.298721, -0.892134, 1.80203]],
                ],
                [
                    [[0.921369, -0.490465, -0.428293], [0.478897, -1.31732, -1.40296]],
                    [[-1.11283, 1.62192, 0.251107], [-0.35957, 0.0634394, -0.467067]],
                ],
            ],
        ]
        self.assertTrue(x.shape == y.shape)
        self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
        # Test repr
        self.assertTrue(str(inorm) == "InstanceNorm(3, eps=1e-05, affine=False)")

    def test_batch_norm(self):
        mx.random.seed(42)
        x = mx.random.normal((5, 4), dtype=mx.float32)