implemented InstanceNorm (#244)

* implemented instancenorm

* implemented vector_norm in cpp

added linalg to mlx

* implemented vector_norm python binding

* renamed vector_norm to norm, implemented norm without provided ord

* completed the implementation of the norm

* added tests

* removed unused import in linalg.cpp

* updated python bindings

* added some tests for python bindings

* handling inf, -inf as numpy does, more extensive tests of compatibility with numpy

* added better docs and examples

* refactored mlx.linalg.norm bindings

* reused existing util for implementation of linalg.norm

* more tests

* fixed a bug with no ord and axis provided

* removed unused imports

* some style and API consistency updates to linalg norm

* remove unused includes

* fix python tests

* fixed a bug with frobenius norm of a complex-valued matrix

* complex for vector too

* addressed PR review comments

* fixed import order in __init__

* expected values in instancenorm tests are simple lists

* minor return expression style change

* added InstanceNorm to docs

* doc string nits

* added myself to individual contributors

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Gabrijel Boduljak
2024-01-03 21:21:15 +01:00
committed by GitHub
parent dff4a3833f
commit c7edafb729
5 changed files with 287 additions and 1 deletions

View File

@@ -46,7 +46,13 @@ from mlx.nn.layers.convolution import Conv1d, Conv2d
from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d
from mlx.nn.layers.embedding import Embedding
from mlx.nn.layers.linear import Bilinear, Identity, Linear
from mlx.nn.layers.normalization import BatchNorm, GroupNorm, LayerNorm, RMSNorm
from mlx.nn.layers.normalization import (
BatchNorm,
GroupNorm,
InstanceNorm,
LayerNorm,
RMSNorm,
)
from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
from mlx.nn.layers.quantized import QuantizedLinear
from mlx.nn.layers.transformer import (

View File

@@ -6,6 +6,66 @@ import mlx.core as mx
from mlx.nn.layers.base import Module
class InstanceNorm(Module):
    r"""Applies instance normalization [1] on the inputs.

    Computes

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta,

    where the mean and variance are taken over every axis of the input
    except the first and the last one. :math:`\gamma` and :math:`\beta` are
    per feature dimension parameters of size :attr:`dims`, initialized at 1
    and 0 respectively; they only exist when :attr:`affine` is ``True``.

    Args:
        dims (int): The number of features of the input.
        eps (float): A value added to the denominator for numerical
            stability. Default: ``1e-5``.
        affine (bool): If ``True``, create the ``weight`` and ``bias``
            parameters and apply them to the normalized input.
            Default: ``False``.

    Shape:
        - Input: :math:`(..., C)` where :math:`C` is equal to :attr:`dims`.
        - Output: Same shape as the input.

    Examples:
        >>> import mlx.core as mx
        >>> import mlx.nn as nn
        >>> x = mx.random.normal((8, 4, 4, 16))
        >>> inorm = nn.InstanceNorm(dims=16)
        >>> output = inorm(x)

    References:
        [1]: https://arxiv.org/abs/1607.08022
    """

    def __init__(
        self,
        dims: int,
        eps: float = 1e-5,
        affine: bool = False,
    ):
        super().__init__()

        # The scale/shift parameters are only registered on request; their
        # presence is what the rest of the class uses to detect affine mode.
        if affine:
            self.weight = mx.ones((dims,))
            self.bias = mx.zeros((dims,))

        self.dims = dims
        self.eps = eps

    def _extra_repr(self):
        # Affine mode is inferred from whether ``weight`` was registered.
        has_affine = "weight" in self
        return f"{self.dims}, eps={self.eps}, affine={has_affine}"

    def __call__(self, x: mx.array) -> mx.array:
        # Every axis except the first and the last is reduced over.
        # NOTE(review): for inputs with fewer than three dimensions this
        # tuple is empty -- confirm that this is the intended behavior.
        inner_axes = tuple(range(1, x.ndim - 1))

        # Statistics keep their reduced axes so they broadcast against x.
        mean = mx.mean(x, axis=inner_axes, keepdims=True)
        var = mx.var(x, axis=inner_axes, keepdims=True)

        # Center, then scale by the reciprocal standard deviation.
        centered = x - mean
        normalized = centered * mx.rsqrt(var + self.eps)

        # Apply the scale and shift only when they were created.
        if "weight" in self:
            return self.weight * normalized + self.bias
        return normalized
class LayerNorm(Module):
r"""Applies layer normalization [1] on the inputs.