implemented instancenorm
parent f24200db2c
commit 6c7bebf428
@@ -36,7 +36,7 @@ from mlx.nn.layers.convolution import Conv1d, Conv2d
 from mlx.nn.layers.dropout import Dropout
 from mlx.nn.layers.embedding import Embedding
 from mlx.nn.layers.linear import Linear
-from mlx.nn.layers.normalization import GroupNorm, LayerNorm, RMSNorm
+from mlx.nn.layers.normalization import GroupNorm, InstanceNorm, LayerNorm, RMSNorm
 from mlx.nn.layers.positional_encoding import RoPE, SinusoidalPositionalEncoding
 from mlx.nn.layers.transformer import (
     MultiHeadAttention,
@@ -4,6 +4,80 @@ import mlx.core as mx
 from mlx.nn.layers.base import Module


+class InstanceNorm(Module):
+    r"""Applies instance normalization [1] on the inputs.
+
+    Computes
+
+    .. math::
+
+        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta,
+
+    where :math:`\gamma` and :math:`\beta` are learned per feature dimension
+    parameters initialized at 1 and 0 respectively. Both have size
+    ``num_features`` and are only used if :attr:`affine` is ``True``.
+
+    [1]: https://arxiv.org/abs/1607.08022
+
+    Args:
+        num_features: number of features of the input
+        eps: a value added to the denominator for numerical stability.
+            Default: ``1e-5``
+        affine: a boolean value that when set to ``True`` gives this module
+            learnable per-feature affine parameters. Default: ``False``
+
+    Shape:
+        - Input: :math:`(N, C, L)`, :math:`(N, C, H, W)` or :math:`(N, C, D, H, W)`
+        - Output: same shape as the input
+    """
+
+    def __init__(
+        self,
+        num_features: int,
+        eps: float = 1e-5,
+        affine: bool = False,
+    ):
+        super().__init__()
+        if affine:
+            self.weight = mx.ones((num_features,))
+            self.bias = mx.zeros((num_features,))
+        self.num_features = num_features
+        self.eps = eps
+        self.affine = affine
+        self.param_shapes = {
+            3: (1, num_features, 1),  # input shape: (B, C, L)
+            4: (1, num_features, 1, 1),  # input shape: (B, C, H, W)
+            5: (1, num_features, 1, 1, 1),  # input shape: (B, C, D, H, W)
+        }
+        self.reduction_axes = {
+            3: [2],  # input shape: (B, C, L)
+            4: [2, 3],  # input shape: (B, C, H, W)
+            5: [2, 3, 4],  # input shape: (B, C, D, H, W)
+        }
+
+    def extra_repr(self):
+        return "{num_features}, eps={eps}, affine={affine}".format(**self.__dict__)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        if x.ndim not in self.reduction_axes:
+            raise ValueError(
+                f"Unsupported input with {x.ndim} dimensions; expected 3, 4, or 5"
+            )
+        if self.affine and self.weight.ndim != x.ndim:
+            # Reshape the affine parameters once so they broadcast over
+            # the spatial dimensions
+            self.weight = mx.reshape(self.weight, self.param_shapes[x.ndim])
+            self.bias = mx.reshape(self.bias, self.param_shapes[x.ndim])
+        # Compute per-sample, per-channel statistics over the spatial axes
+        mean = mx.mean(x, axis=self.reduction_axes[x.ndim], keepdims=True)
+        var = mx.var(x, axis=self.reduction_axes[x.ndim], keepdims=True)
+        # Normalize
+        normalized = (x - mean) * mx.rsqrt(var + self.eps)
+        # Scale and shift if necessary
+        if self.affine:
+            return self.weight * normalized + self.bias
+        else:
+            return normalized
+
+
 class LayerNorm(Module):
     r"""Applies layer normalization [1] on the inputs.
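For reference, a minimal usage sketch of the module added above (the input shapes and values here are illustrative, not part of the commit):

    import mlx.core as mx
    import mlx.nn as nn

    # A batch of 2 samples, 3 channels, length 4 -- the (N, C, L) layout.
    x = mx.random.normal((2, 3, 4))

    # Without affine parameters, each (sample, channel) slice is simply
    # normalized to zero mean and unit variance.
    inorm = nn.InstanceNorm(num_features=3)
    y = inorm(x)  # same shape as x

    # With affine=True, a learnable per-channel scale and shift is applied.
    inorm_affine = nn.InstanceNorm(num_features=3, affine=True)
    y_affine = inorm_affine(x)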
@@ -226,6 +226,197 @@ class TestNN(mlx_tests.MLXTestCase):
         self.assertTrue(np.allclose(means, 3 * np.ones_like(means), atol=1e-6))
         self.assertTrue(np.allclose(var, 4 * np.ones_like(var), atol=1e-6))

+    def test_instance_norm(self):
+        # Test InstanceNorm1d
+        x = mx.array(
+            [
+                [
+                    [-0.0119524, -0.500331, 1.12958, 1.39955],
+                    [1.1263, 0.517899, -0.21413, 0.891329],
+                    [2.02223, -1.21143, -2.48738, 1.63289],
+                ],
+                [
+                    [0.241417, -1.42512, 2.739, -1.23175],
+                    [-0.619157, 0.970817, -1.2506, 0.32756],
+                    [-0.77484, -1.31352, 1.56844, 1.13969],
+                ],
+            ]
+        )
+        inorm = nn.InstanceNorm(num_features=3)
+        y = inorm(x)
+        expected_y = [
+            [
+                [-0.657082, -1.27879, 0.796097, 1.13978],
+                [1.07593, -0.123075, -1.56572, 0.61286],
+                [1.0712, -0.632503, -1.30476, 0.866066],
+            ],
+            [
+                [0.0964433, -0.904773, 1.59693, -0.788599],
+                [-0.557908, 1.30444, -1.29751, 0.550987],
+                [-0.759886, -1.20013, 1.15521, 0.804804],
+            ],
+        ]
+        self.assertTrue(x.shape == y.shape)
+        self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
+        # Test InstanceNorm2d
+        x = mx.array(
+            [
+                [
+                    [
+                        [-0.458824, -0.447996, 0.0486988],
+                        [1.13049, 0.301795, -2.23876],
+                        [0.0986325, -1.25257, -0.329399],
+                    ],
+                    [
+                        [0.483254, -0.176577, -0.0611224],
+                        [0.345315, 0.99207, -0.758631],
+                        [-1.82973, 0.154442, -0.319107],
+                    ],
+                    [
+                        [-0.58611, -0.622545, 1.8845],
+                        [-0.926389, -0.184927, -1.12639],
+                        [-0.241765, -0.556204, 0.830584],
+                    ],
+                ],
+                [
+                    [
+                        [1.04407, 0.0800776, 0.782321],
+                        [0.671423, -0.110299, 0.159905],
+                        [0.810252, 0.182597, -0.0621687],
+                    ],
+                    [
+                        [0.073752, 1.2513, -0.444367],
+                        [-1.21689, -1.42248, 0.516452],
+                        [1.50456, 0.0576239, 0.184253],
+                    ],
+                    [
+                        [0.407081, 1.20627, 0.563132],
+                        [-1.88979, 1.17838, -0.539121],
+                        [1.08659, 0.973883, 0.784216],
+                    ],
+                ],
+            ]
+        )
+        inorm = nn.InstanceNorm(num_features=3)
+        y = inorm(x)
+        expected_y = [
+            [
+                [
+                    [-0.120422, -0.108465, 0.440008],
+                    [1.63457, 0.719488, -2.08591],
+                    [0.495147, -0.996913, 0.0224944],
+                ],
+                [
+                    [0.801504, -0.0608616, 0.0900314],
+                    [0.621224, 1.4665, -0.821576],
+                    [-2.22144, 0.371763, -0.247141],
+                ],
+                [
+                    [-0.463984, -0.504602, 2.29032],
+                    [-0.843336, -0.0167355, -1.0663],
+                    [-0.0800997, -0.430644, 1.11538],
+                ],
+            ],
+            [
+                [
+                    [1.59749, -0.776381, 0.95293],
+                    [0.679838, -1.24519, -0.579803],
+                    [1.02171, -0.523923, -1.12667],
+                ],
+                [
+                    [0.0190289, 1.28291, -0.537076],
+                    [-1.36624, -1.5869, 0.494185],
+                    [1.55474, 0.00171834, 0.137631],
+                ],
+                [
+                    [-0.012331, 0.817234, 0.149652],
+                    [-2.39651, 0.78829, -0.994498],
+                    [0.693007, 0.576016, 0.37914],
+                ],
+            ],
+        ]
+        self.assertTrue(x.shape == y.shape)
+        self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
+        # Test InstanceNorm3d
+        x = mx.array(
+            [
+                [
+                    [
+                        [[0.777621, -2.1722], [-1.41317, 0.284446]],
+                        [[0.11, -0.837743], [-2.40205, 0.336682]],
+                        [[0.789185, -1.42998], [-0.459489, 0.0298199]],
+                    ],
+                    [
+                        [[0.528145, 0.128192], [0.476288, -0.649858]],
+                        [[-0.12431, 1.93502], [-1.25873, -0.261986]],
+                        [[-1.63747, -1.73247], [-2.15559, 0.10275]],
+                    ],
+                    [
+                        [[-1.56133, 0.153862], [-1.20411, 0.152112]],
+                        [[1.18768, 0.00236324], [-2.04243, 1.54289]],
+                        [[0.67917, -0.402572], [-0.249959, -0.821897]],
+                    ],
+                ],
+                [
+                    [
+                        [[-2.12354, 0.317797], [-0.146628, 0.0329215]],
+                        [[-1.55784, 2.41031], [0.226341, 0.265387]],
+                        [[0.990317, 0.475161], [-1.37804, -0.501041]],
+                    ],
+                    [
+                        [[0.643973, -0.682916], [-0.987925, 1.54086]],
+                        [[0.71179, -0.290786], [0.057712, -0.742304]],
+                        [[-0.399875, -1.10479], [1.40097, 0.0723374]],
+                    ],
+                    [
+                        [[0.72391, 0.016364], [0.573199, 0.213092]],
+                        [[-0.0678402, 0.00449439], [-1.58342, 1.28133]],
+                        [[-0.357647, -1.07389], [0.141618, -0.386141]],
+                    ],
+                ],
+            ]
+        )
+        inorm = nn.InstanceNorm(num_features=3)
+        y = inorm(x)
+        expected_y = [
+            [
+                [
+                    [[1.23593, -1.54739], [-0.831204, 0.770588]],
+                    [[0.605988, -0.288258], [-1.76427, 0.819875]],
+                    [[1.24684, -0.847068], [0.0686449, 0.530334]],
+                ],
+                [
+                    [[0.821849, 0.462867], [0.775304, -0.23548]],
+                    [[0.236231, 2.0846], [-0.78198, 0.112659]],
+                    [[-1.12192, -1.20719], [-1.58697, 0.440032]],
+                ],
+                [
+                    [[-1.30944, 0.357126], [-0.962338, 0.355425]],
+                    [[1.36163, 0.209922], [-1.77689, 1.70677]],
+                    [[0.867539, -0.183531], [-0.0352458, -0.590967]],
+                ],
+            ],
+            [
+                [
+                    [[-1.75315, 0.343736], [-0.0551618, 0.0990544]],
+                    [[-1.26726, 2.14101], [0.265184, 0.298721]],
+                    [[0.921369, 0.478897], [-1.11283, -0.35957]],
+                ],
+                [
+                    [[0.733967, -0.822472], [-1.18025, 1.78602]],
+                    [[0.813517, -0.362504], [0.0462839, -0.892134]],
+                    [[-0.490465, -1.31732], [1.62192, 0.0634394]],
+                ],
+                [
+                    [[1.04349, 0.080661], [0.838402, 0.348368]],
+                    [[-0.033924, 0.0645089], [-2.09632, 1.80203]],
+                    [[-0.428293, -1.40296], [0.251107, -0.467067]],
+                ],
+            ],
+        ]
+        self.assertTrue(x.shape == y.shape)
+        self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
+
     def test_conv1d(self):
         N = 5
         L = 12
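The expected values in these tests can be sanity-checked against a plain NumPy reference; a minimal sketch, where the helper name instance_norm_ref is ours and not part of the commit:

    import numpy as np

    def instance_norm_ref(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
        # Normalize over every axis after N and C, independently per
        # (sample, channel) pair, using the population variance as mx.var does.
        axes = tuple(range(2, x.ndim))
        mean = x.mean(axis=axes, keepdims=True)
        var = x.var(axis=axes, keepdims=True)
        return (x - mean) / np.sqrt(var + eps)

    # E.g. for the (N, C, L) case: each (sample, channel) slice of the
    # output has zero mean.
    x = np.random.randn(2, 3, 4)
    y = instance_norm_ref(x)
    assert np.allclose(y.mean(axis=2), 0.0, atol=1e-6)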