implemented InstanceNorm (#244)

* implemented instancenorm

* implemented vector_norm in cpp

added linalg to mlx

* implemented vector_norm python binding

* renamed vector_norm to norm, implemented norm without provided ord

* completed the implementation of the norm

* added tests

* removed unused import in linalg.cpp

* updated python bindings

* added some tests for python bindings

* handling inf, -inf as numpy does, more extensive tests of compatibility with numpy

* added better docs and examples

* refactored mlx.linalg.norm bindings

* reused existing util for implementation of linalg.norm

* more tests

* fixed a bug with no ord and axis provided

* removed unused imports

* some style and API consistency updates to linalg norm

* remove unused includes

* fix python tests

* fixed a bug with frobenius norm of a complex-valued matrix

* complex for vector too

* addressed PR review comments

* fixed import order in __init__

* expected values in instancenorm tests are simple lists

* minor return expression style change

* added InstanceNorm to docs

* doc string nits

* added myself to individual contributors

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Gabrijel Boduljak
2024-01-03 21:21:15 +01:00
committed by GitHub
parent dff4a3833f
commit c7edafb729
5 changed files with 287 additions and 1 deletions

View File

@@ -172,6 +172,224 @@ class TestLayers(mlx_tests.MLXTestCase):
self.assertTrue(np.allclose(means, 3 * np.ones_like(means), atol=1e-6))
self.assertTrue(np.allclose(var, 4 * np.ones_like(var), atol=1e-6))
def test_instance_norm(self):
# Test InstanceNorm1d
x = mx.array(
[
[
[-0.0119524, 1.1263, 2.02223],
[-0.500331, 0.517899, -1.21143],
[1.12958, -0.21413, -2.48738],
[1.39955, 0.891329, 1.63289],
],
[
[0.241417, -0.619157, -0.77484],
[-1.42512, 0.970817, -1.31352],
[2.739, -1.2506, 1.56844],
[-1.23175, 0.32756, 1.13969],
],
]
)
inorm = nn.InstanceNorm(dims=3)
y = inorm(x)
expected_y = [
[
[-0.657082, 1.07593, 1.0712],
[-1.27879, -0.123074, -0.632505],
[0.796101, -1.56572, -1.30476],
[1.13978, 0.612862, 0.866067],
],
[
[0.0964426, -0.557906, -0.759885],
[-0.904772, 1.30444, -1.20013],
[1.59693, -1.29752, 1.15521],
[-0.7886, 0.550987, 0.804807],
],
]
self.assertTrue(x.shape == y.shape)
self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
# Test InstanceNorm2d
x = mx.array(
[
[
[
[-0.458824, 0.483254, -0.58611],
[-0.447996, -0.176577, -0.622545],
[0.0486988, -0.0611224, 1.8845],
],
[
[1.13049, 0.345315, -0.926389],
[0.301795, 0.99207, -0.184927],
[-2.23876, -0.758631, -1.12639],
],
[
[0.0986325, -1.82973, -0.241765],
[-1.25257, 0.154442, -0.556204],
[-0.329399, -0.319107, 0.830584],
],
],
[
[
[1.04407, 0.073752, 0.407081],
[0.0800776, 1.2513, 1.20627],
[0.782321, -0.444367, 0.563132],
],
[
[0.671423, -1.21689, -1.88979],
[-0.110299, -1.42248, 1.17838],
[0.159905, 0.516452, -0.539121],
],
[
[0.810252, 1.50456, 1.08659],
[0.182597, 0.0576239, 0.973883],
[-0.0621687, 0.184253, 0.784216],
],
],
]
)
inorm = nn.InstanceNorm(dims=3)
y = inorm(x)
expected_y = [
[
[
[-0.120422, 0.801503, -0.463983],
[-0.108465, -0.0608611, -0.504602],
[0.440008, 0.090032, 2.29032],
],
[
[1.63457, 0.621224, -0.843335],
[0.719488, 1.4665, -0.0167344],
[-2.08591, -0.821575, -1.0663],
],
[
[0.495147, -2.22145, -0.0800989],
[-0.996913, 0.371763, -0.430643],
[0.022495, -0.24714, 1.11538],
],
],
[
[
[1.5975, 0.0190292, -0.0123306],
[-0.776381, 1.28291, 0.817237],
[0.952927, -0.537076, 0.149652],
],
[
[0.679836, -1.36624, -2.39651],
[-1.24519, -1.5869, 0.788287],
[-0.579802, 0.494186, -0.994499],
],
[
[1.02171, 1.55474, 0.693008],
[-0.523922, 0.00171862, 0.576016],
[-1.12667, 0.137632, 0.37914],
],
],
]
self.assertTrue(x.shape == y.shape)
self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
# # Test InstanceNorm3d
x = mx.array(
[
[
[
[[0.777621, 0.528145, -1.56133], [-2.1722, 0.128192, 0.153862]],
[
[-1.41317, 0.476288, -1.20411],
[0.284446, -0.649858, 0.152112],
],
],
[
[[0.11, -0.12431, 1.18768], [-0.837743, 1.93502, 0.00236324]],
[
[-2.40205, -1.25873, -2.04243],
[0.336682, -0.261986, 1.54289],
],
],
[
[
[0.789185, -1.63747, 0.67917],
[-1.42998, -1.73247, -0.402572],
],
[
[-0.459489, -2.15559, -0.249959],
[0.0298199, 0.10275, -0.821897],
],
],
],
[
[
[
[-2.12354, 0.643973, 0.72391],
[0.317797, -0.682916, 0.016364],
],
[
[-0.146628, -0.987925, 0.573199],
[0.0329215, 1.54086, 0.213092],
],
],
[
[
[-1.55784, 0.71179, -0.0678402],
[2.41031, -0.290786, 0.00449439],
],
[
[0.226341, 0.057712, -1.58342],
[0.265387, -0.742304, 1.28133],
],
],
[
[
[0.990317, -0.399875, -0.357647],
[0.475161, -1.10479, -1.07389],
],
[
[-1.37804, 1.40097, 0.141618],
[-0.501041, 0.0723374, -0.386141],
],
],
],
]
)
inorm = nn.InstanceNorm(dims=3)
y = inorm(x)
expected_y = [
[
[
[[1.23593, 0.821849, -1.30944], [-1.54739, 0.462867, 0.357126]],
[[-0.831204, 0.775304, -0.962338], [0.770588, -0.23548, 0.355425]],
],
[
[[0.605988, 0.236231, 1.36163], [-0.288258, 2.0846, 0.209922]],
[[-1.76427, -0.78198, -1.77689], [0.819875, 0.112659, 1.70677]],
],
[
[[1.24684, -1.12192, 0.867539], [-0.847068, -1.20719, -0.183531]],
[
[0.0686449, -1.58697, -0.0352458],
[0.530334, 0.440032, -0.590967],
],
],
],
[
[
[[-1.75315, 0.733967, 1.04349], [0.343736, -0.822472, 0.080661]],
[[-0.0551618, -1.18025, 0.838402], [0.0990544, 1.78602, 0.348368]],
],
[
[[-1.26726, 0.813517, -0.033924], [2.14101, -0.362504, 0.0645089]],
[[0.265184, 0.0462839, -2.09632], [0.298721, -0.892134, 1.80203]],
],
[
[[0.921369, -0.490465, -0.428293], [0.478897, -1.31732, -1.40296]],
[[-1.11283, 1.62192, 0.251107], [-0.35957, 0.0634394, -0.467067]],
],
],
]
self.assertTrue(x.shape == y.shape)
self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
# Test repr
self.assertTrue(str(inorm) == "InstanceNorm(3, eps=1e-05, affine=False)")
def test_batch_norm(self):
mx.random.seed(42)
x = mx.random.normal((5, 4), dtype=mx.float32)