Mirror of https://github.com/ml-explore/mlx.git
implemented InstanceNorm (#244)
* implemented instancenorm
* implemented vector_norm in cpp, added linalg to mlx
* implemented vector_norm python binding
* renamed vector_norm to norm, implemented norm without provided ord
* completed the implementation of the norm
* added tests
* removed unused import in linalg.cpp
* updated python bindings
* added some tests for python bindings
* handling inf, -inf as numpy does, more extensive tests of compatibility with numpy
* added better docs and examples
* refactored mlx.linalg.norm bindings
* reused existing util for implementation of linalg.norm
* more tests
* fixed a bug with no ord and axis provided
* removed unused imports
* some style and API consistency updates to linalg norm
* remove unused includes
* fix python tests
* fixed a bug with frobenius norm of a complex-valued matrix
* complex for vector too
* addressed PR review comments
* fixed import order in __init__
* expected values in instancenorm tests are simple lists
* minor return expression style change
* added InstanceNorm to docs
* doc string nits
* added myself to individual contributors

---------

Co-authored-by: Awni Hannun <awni@apple.com>
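Several of the squashed entries above describe the mlx.linalg.norm work that landed alongside this change. The following is a rough usage sketch of the NumPy-compatible behavior those entries describe; the exact signature is assumed here and is not part of the diff below.

import mlx.core as mx

# Assumed numpy.linalg.norm-style API, per the commit message
# ("norm without provided ord", "handling inf, -inf as numpy does").
v = mx.arange(9, dtype=mx.float32) - 4   # vector [-4, ..., 4]
m = v.reshape(3, 3)                      # 3x3 matrix

mx.linalg.norm(v)                        # default: L2 norm of a vector
mx.linalg.norm(v, ord=float("inf"))      # max(|v_i|), matching numpy
mx.linalg.norm(m)                        # default: Frobenius norm of a matrix
mx.linalg.norm(m, axis=0)                # per-column vector norms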
commit c7edafb729
parent dff4a3833f
committed by GitHub
@@ -46,7 +46,13 @@ from mlx.nn.layers.convolution import Conv1d, Conv2d
 from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d
 from mlx.nn.layers.embedding import Embedding
 from mlx.nn.layers.linear import Bilinear, Identity, Linear
-from mlx.nn.layers.normalization import BatchNorm, GroupNorm, LayerNorm, RMSNorm
+from mlx.nn.layers.normalization import (
+    BatchNorm,
+    GroupNorm,
+    InstanceNorm,
+    LayerNorm,
+    RMSNorm,
+)
 from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
 from mlx.nn.layers.quantized import QuantizedLinear
 from mlx.nn.layers.transformer import (
@@ -6,6 +6,66 @@ import mlx.core as mx
 from mlx.nn.layers.base import Module
 
 
+class InstanceNorm(Module):
+    r"""Applies instance normalization [1] on the inputs.
+
+    Computes
+
+    .. math::
+
+        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta,
+
+    where :math:`\gamma` and :math:`\beta` are learned per feature dimension
+    parameters initialized at 1 and 0 respectively. Both are of size :attr:`dims`,
+    if :attr:`affine` is ``True``.
+
+    Args:
+        dims (int): The number of features of the input.
+        eps (float): A value added to the denominator for numerical stability. Default: ``1e-5``.
+        affine (bool): If ``True``, learn an affine transform (:math:`\gamma` and :math:`\beta`) applied after the normalization. Default: ``False``.
+
+    Shape:
+      - Input: :math:`(..., C)` where :math:`C` is equal to :attr:`dims`.
+      - Output: Same shape as the input.
+
+    Examples:
+        >>> import mlx.core as mx
+        >>> import mlx.nn as nn
+        >>> x = mx.random.normal((8, 4, 4, 16))
+        >>> inorm = nn.InstanceNorm(dims=16)
+        >>> output = inorm(x)
+
+    References:
+        [1]: https://arxiv.org/abs/1607.08022
+    """
+
+    def __init__(
+        self,
+        dims: int,
+        eps: float = 1e-5,
+        affine: bool = False,
+    ):
+        super().__init__()
+        if affine:
+            self.weight = mx.ones((dims,))
+            self.bias = mx.zeros((dims,))
+        self.dims = dims
+        self.eps = eps
+
+    def _extra_repr(self):
+        return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"
+
+    def __call__(self, x: mx.array) -> mx.array:
+        reduction_axes = tuple(range(1, x.ndim - 1))
+        # Compute per-instance statistics over the spatial axes
+        mean = mx.mean(x, axis=reduction_axes, keepdims=True)
+        var = mx.var(x, axis=reduction_axes, keepdims=True)
+        # Normalize
+        x = (x - mean) * mx.rsqrt(var + self.eps)
+        # Scale and shift only if the affine parameters were created
+        return (self.weight * x + self.bias) if "weight" in self else x
+
+
 class LayerNorm(Module):
     r"""Applies layer normalization [1] on the inputs.