mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
implemented InstanceNorm (#244)
* implemented instancenorm * implemented vector_norm in cpp added linalg to mlx * implemented vector_norm python binding * renamed vector_norm to norm, implemented norm without provided ord * completed the implementation of the norm * added tests * removed unused import in linalg.cpp * updated python bindings * added some tests for python bindings * handling inf, -inf as numpy does, more extensive tests of compatibility with numpy * added better docs and examples * refactored mlx.linalg.norm bindings * reused existing util for implementation of linalg.norm * more tests * fixed a bug with no ord and axis provided * removed unused imports * some style and API consistency updates to linalg norm * remove unused includes * fix python tests * fixed a bug with frobenius norm of a complex-valued matrix * complex for vector too * addressed PR review comments * fixed import order in __init__ * expected values in instancenorm tests are simple lists * minor return expression style change * added InstanceNorm to docs * doc string nits * added myself to individual contributors --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
dff4a3833f
commit
c7edafb729
@ -11,6 +11,7 @@ MLX was developed with contributions from the following individuals:
|
||||
- Juarez Bochi: Fixed bug in cross attention.
|
||||
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
|
||||
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot` and safetensor support
|
||||
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer.
|
||||
|
||||
# Third-Party Software
|
||||
|
||||
|
@ -26,6 +26,7 @@ Layers
|
||||
LayerNorm
|
||||
RMSNorm
|
||||
GroupNorm
|
||||
InstanceNorm
|
||||
Dropout
|
||||
Dropout2d
|
||||
Dropout3d
|
||||
|
@ -46,7 +46,13 @@ from mlx.nn.layers.convolution import Conv1d, Conv2d
|
||||
from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d
|
||||
from mlx.nn.layers.embedding import Embedding
|
||||
from mlx.nn.layers.linear import Bilinear, Identity, Linear
|
||||
from mlx.nn.layers.normalization import BatchNorm, GroupNorm, LayerNorm, RMSNorm
|
||||
from mlx.nn.layers.normalization import (
|
||||
BatchNorm,
|
||||
GroupNorm,
|
||||
InstanceNorm,
|
||||
LayerNorm,
|
||||
RMSNorm,
|
||||
)
|
||||
from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
|
||||
from mlx.nn.layers.quantized import QuantizedLinear
|
||||
from mlx.nn.layers.transformer import (
|
||||
|
@ -6,6 +6,66 @@ import mlx.core as mx
|
||||
from mlx.nn.layers.base import Module
|
||||
|
||||
|
||||
class InstanceNorm(Module):
|
||||
r"""Applies instance normalization [1] on the inputs.
|
||||
|
||||
Computes
|
||||
|
||||
.. math::
|
||||
|
||||
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta,
|
||||
|
||||
where :math:`\gamma` and :math:`\beta` are learned per feature dimension
|
||||
parameters initialized at 1 and 0 respectively. Both are of size :attr:`dims`,
|
||||
if :attr:`affine` is ``True``.
|
||||
|
||||
Args:
|
||||
dims (int): The number of features of the input.
|
||||
eps (float): A value added to the denominator for numerical stability. Default: ``1e-5``.
|
||||
affine (bool): Default: ``False``.
|
||||
|
||||
Shape:
|
||||
- Input: :math:`(..., C)` where :math:`C` is equal to :attr:`dims`.
|
||||
- Output: Same shape as the input.
|
||||
|
||||
Examples:
|
||||
>>> import mlx.core as mx
|
||||
>>> import mlx.nn as nn
|
||||
>>> x = mx.random.normal((8, 4, 4, 16))
|
||||
>>> inorm = nn.InstanceNorm(dims=16)
|
||||
>>> output = inorm(x)
|
||||
|
||||
References:
|
||||
[1]: https://arxiv.org/abs/1607.08022
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dims: int,
|
||||
eps: float = 1e-5,
|
||||
affine: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
if affine:
|
||||
self.weight = mx.ones((dims,))
|
||||
self.bias = mx.zeros((dims,))
|
||||
self.dims = dims
|
||||
self.eps = eps
|
||||
|
||||
def _extra_repr(self):
|
||||
return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
reduction_axes = tuple(range(1, x.ndim - 1))
|
||||
# Compute stats
|
||||
mean = mx.mean(x, axis=reduction_axes, keepdims=True)
|
||||
var = mx.var(x, axis=reduction_axes, keepdims=True)
|
||||
# Normalize
|
||||
x = (x - mean) * mx.rsqrt(var + self.eps)
|
||||
# Scale and shift if necessary
|
||||
return (self.weight * x + self.bias) if "weight" in self else x
|
||||
|
||||
|
||||
class LayerNorm(Module):
|
||||
r"""Applies layer normalization [1] on the inputs.
|
||||
|
||||
|
@ -172,6 +172,224 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(np.allclose(means, 3 * np.ones_like(means), atol=1e-6))
|
||||
self.assertTrue(np.allclose(var, 4 * np.ones_like(var), atol=1e-6))
|
||||
|
||||
def test_instance_norm(self):
|
||||
# Test InstanceNorm1d
|
||||
x = mx.array(
|
||||
[
|
||||
[
|
||||
[-0.0119524, 1.1263, 2.02223],
|
||||
[-0.500331, 0.517899, -1.21143],
|
||||
[1.12958, -0.21413, -2.48738],
|
||||
[1.39955, 0.891329, 1.63289],
|
||||
],
|
||||
[
|
||||
[0.241417, -0.619157, -0.77484],
|
||||
[-1.42512, 0.970817, -1.31352],
|
||||
[2.739, -1.2506, 1.56844],
|
||||
[-1.23175, 0.32756, 1.13969],
|
||||
],
|
||||
]
|
||||
)
|
||||
inorm = nn.InstanceNorm(dims=3)
|
||||
y = inorm(x)
|
||||
expected_y = [
|
||||
[
|
||||
[-0.657082, 1.07593, 1.0712],
|
||||
[-1.27879, -0.123074, -0.632505],
|
||||
[0.796101, -1.56572, -1.30476],
|
||||
[1.13978, 0.612862, 0.866067],
|
||||
],
|
||||
[
|
||||
[0.0964426, -0.557906, -0.759885],
|
||||
[-0.904772, 1.30444, -1.20013],
|
||||
[1.59693, -1.29752, 1.15521],
|
||||
[-0.7886, 0.550987, 0.804807],
|
||||
],
|
||||
]
|
||||
self.assertTrue(x.shape == y.shape)
|
||||
self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
|
||||
# Test InstanceNorm2d
|
||||
x = mx.array(
|
||||
[
|
||||
[
|
||||
[
|
||||
[-0.458824, 0.483254, -0.58611],
|
||||
[-0.447996, -0.176577, -0.622545],
|
||||
[0.0486988, -0.0611224, 1.8845],
|
||||
],
|
||||
[
|
||||
[1.13049, 0.345315, -0.926389],
|
||||
[0.301795, 0.99207, -0.184927],
|
||||
[-2.23876, -0.758631, -1.12639],
|
||||
],
|
||||
[
|
||||
[0.0986325, -1.82973, -0.241765],
|
||||
[-1.25257, 0.154442, -0.556204],
|
||||
[-0.329399, -0.319107, 0.830584],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[1.04407, 0.073752, 0.407081],
|
||||
[0.0800776, 1.2513, 1.20627],
|
||||
[0.782321, -0.444367, 0.563132],
|
||||
],
|
||||
[
|
||||
[0.671423, -1.21689, -1.88979],
|
||||
[-0.110299, -1.42248, 1.17838],
|
||||
[0.159905, 0.516452, -0.539121],
|
||||
],
|
||||
[
|
||||
[0.810252, 1.50456, 1.08659],
|
||||
[0.182597, 0.0576239, 0.973883],
|
||||
[-0.0621687, 0.184253, 0.784216],
|
||||
],
|
||||
],
|
||||
]
|
||||
)
|
||||
inorm = nn.InstanceNorm(dims=3)
|
||||
y = inorm(x)
|
||||
expected_y = [
|
||||
[
|
||||
[
|
||||
[-0.120422, 0.801503, -0.463983],
|
||||
[-0.108465, -0.0608611, -0.504602],
|
||||
[0.440008, 0.090032, 2.29032],
|
||||
],
|
||||
[
|
||||
[1.63457, 0.621224, -0.843335],
|
||||
[0.719488, 1.4665, -0.0167344],
|
||||
[-2.08591, -0.821575, -1.0663],
|
||||
],
|
||||
[
|
||||
[0.495147, -2.22145, -0.0800989],
|
||||
[-0.996913, 0.371763, -0.430643],
|
||||
[0.022495, -0.24714, 1.11538],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[1.5975, 0.0190292, -0.0123306],
|
||||
[-0.776381, 1.28291, 0.817237],
|
||||
[0.952927, -0.537076, 0.149652],
|
||||
],
|
||||
[
|
||||
[0.679836, -1.36624, -2.39651],
|
||||
[-1.24519, -1.5869, 0.788287],
|
||||
[-0.579802, 0.494186, -0.994499],
|
||||
],
|
||||
[
|
||||
[1.02171, 1.55474, 0.693008],
|
||||
[-0.523922, 0.00171862, 0.576016],
|
||||
[-1.12667, 0.137632, 0.37914],
|
||||
],
|
||||
],
|
||||
]
|
||||
self.assertTrue(x.shape == y.shape)
|
||||
self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
|
||||
# # Test InstanceNorm3d
|
||||
x = mx.array(
|
||||
[
|
||||
[
|
||||
[
|
||||
[[0.777621, 0.528145, -1.56133], [-2.1722, 0.128192, 0.153862]],
|
||||
[
|
||||
[-1.41317, 0.476288, -1.20411],
|
||||
[0.284446, -0.649858, 0.152112],
|
||||
],
|
||||
],
|
||||
[
|
||||
[[0.11, -0.12431, 1.18768], [-0.837743, 1.93502, 0.00236324]],
|
||||
[
|
||||
[-2.40205, -1.25873, -2.04243],
|
||||
[0.336682, -0.261986, 1.54289],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[0.789185, -1.63747, 0.67917],
|
||||
[-1.42998, -1.73247, -0.402572],
|
||||
],
|
||||
[
|
||||
[-0.459489, -2.15559, -0.249959],
|
||||
[0.0298199, 0.10275, -0.821897],
|
||||
],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[
|
||||
[-2.12354, 0.643973, 0.72391],
|
||||
[0.317797, -0.682916, 0.016364],
|
||||
],
|
||||
[
|
||||
[-0.146628, -0.987925, 0.573199],
|
||||
[0.0329215, 1.54086, 0.213092],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[-1.55784, 0.71179, -0.0678402],
|
||||
[2.41031, -0.290786, 0.00449439],
|
||||
],
|
||||
[
|
||||
[0.226341, 0.057712, -1.58342],
|
||||
[0.265387, -0.742304, 1.28133],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[0.990317, -0.399875, -0.357647],
|
||||
[0.475161, -1.10479, -1.07389],
|
||||
],
|
||||
[
|
||||
[-1.37804, 1.40097, 0.141618],
|
||||
[-0.501041, 0.0723374, -0.386141],
|
||||
],
|
||||
],
|
||||
],
|
||||
]
|
||||
)
|
||||
inorm = nn.InstanceNorm(dims=3)
|
||||
y = inorm(x)
|
||||
expected_y = [
|
||||
[
|
||||
[
|
||||
[[1.23593, 0.821849, -1.30944], [-1.54739, 0.462867, 0.357126]],
|
||||
[[-0.831204, 0.775304, -0.962338], [0.770588, -0.23548, 0.355425]],
|
||||
],
|
||||
[
|
||||
[[0.605988, 0.236231, 1.36163], [-0.288258, 2.0846, 0.209922]],
|
||||
[[-1.76427, -0.78198, -1.77689], [0.819875, 0.112659, 1.70677]],
|
||||
],
|
||||
[
|
||||
[[1.24684, -1.12192, 0.867539], [-0.847068, -1.20719, -0.183531]],
|
||||
[
|
||||
[0.0686449, -1.58697, -0.0352458],
|
||||
[0.530334, 0.440032, -0.590967],
|
||||
],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[[-1.75315, 0.733967, 1.04349], [0.343736, -0.822472, 0.080661]],
|
||||
[[-0.0551618, -1.18025, 0.838402], [0.0990544, 1.78602, 0.348368]],
|
||||
],
|
||||
[
|
||||
[[-1.26726, 0.813517, -0.033924], [2.14101, -0.362504, 0.0645089]],
|
||||
[[0.265184, 0.0462839, -2.09632], [0.298721, -0.892134, 1.80203]],
|
||||
],
|
||||
[
|
||||
[[0.921369, -0.490465, -0.428293], [0.478897, -1.31732, -1.40296]],
|
||||
[[-1.11283, 1.62192, 0.251107], [-0.35957, 0.0634394, -0.467067]],
|
||||
],
|
||||
],
|
||||
]
|
||||
self.assertTrue(x.shape == y.shape)
|
||||
self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
|
||||
# Test repr
|
||||
self.assertTrue(str(inorm) == "InstanceNorm(3, eps=1e-05, affine=False)")
|
||||
|
||||
def test_batch_norm(self):
|
||||
mx.random.seed(42)
|
||||
x = mx.random.normal((5, 4), dtype=mx.float32)
|
||||
|
Loading…
Reference in New Issue
Block a user