From a0b2a34e9871824dce26a371a38e7fa1ba2258e4 Mon Sep 17 00:00:00 2001
From: m0saan
Date: Tue, 19 Dec 2023 09:16:53 +0100
Subject: [PATCH] rebasing ...

---
 python/mlx/nn/layers/normalization.py | 44 +++++++++++++++++----
 python/tests/test_nn.py               | 56 +++++++++++++++++++++++++++
 2 files changed, 92 insertions(+), 8 deletions(-)

diff --git a/python/mlx/nn/layers/normalization.py b/python/mlx/nn/layers/normalization.py
index 72058aae3..92847e423 100644
--- a/python/mlx/nn/layers/normalization.py
+++ b/python/mlx/nn/layers/normalization.py
@@ -182,14 +182,6 @@ class GroupNorm(Module):
         return (self.weight * x + self.bias) if "weight" in self else x
 
 
-# Copyright © 2023 Apple Inc.
-
-from typing import Tuple
-
-import mlx.core as mx
-from mlx.nn.layers.base import Module
-
-
 class BatchNorm1d(Module):
     r"""Applies Batch Normalization over a 2D or 3D input.
 
@@ -291,3 +283,39 @@ class BatchNorm1d(Module):
             means, var = self.running_mean, self.running_var
         x = (x - means) * mx.rsqrt(var + self.eps)
         return (self.weight * x + self.bias) if "weight" in self else x
+
+
+# import unittest
+# import numpy as np
+
+# class TestBatchNorm1d(unittest.TestCase):
+#     def setUp(self):
+#         self.bn = BatchNorm1d(10)
+
+#     def test_forward(self):
+#         x = mx.random.uniform(shape=(20, 10))
+#         y = self.bn(x)
+#         expected =
+#         self.assertEqual(y.shape, x.shape)
+
+#     def test_running_stats(self):
+#         x = mx.random.uniform(shape=(20, 10))
+#         self.bn(x)
+#         self.assertNotEqual(mx.sum(self.bn.running_mean), 0)
+#         self.assertNotEqual(mx.sum(self.bn.running_var), 10)
+
+#     def test_eval_mode(self):
+#         x = mx.random.uniform(shape=(20, 10))
+#         self.bn(x)
+#         self.bn.training = False
+#         y = self.bn(x)
+#         self.assertEqual(y.shape, x.shape)
+
+#     def test_no_affine(self):
+#         bn = BatchNorm1d(10, affine=False)
+#         x = mx.random.uniform(shape=(20, 10))
+#         y = bn(x)
+#         self.assertEqual(y.shape, x.shape)
+
+# if __name__ == '__main__':
+#     unittest.main()
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index cc56bc430..63805ab44 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -320,6 +320,62 @@ class TestNN(mlx_tests.MLXTestCase):
         self.assertTrue(np.allclose(means, 3 * np.ones_like(means), atol=1e-6))
         self.assertTrue(np.allclose(var, 4 * np.ones_like(var), atol=1e-6))
 
+    def test_batch_norm(self):
+        mx.random.seed(42)
+        x = mx.random.normal((5, 4), dtype=mx.float32)
+
+        # Batch norm
+        bn = nn.BatchNorm1d(num_features=4, affine=True)
+        self.assertTrue(mx.allclose(bn.running_mean, mx.zeros_like(bn.running_mean)))
+        self.assertTrue(mx.allclose(bn.running_var, mx.ones_like(bn.running_var)))
+        y = bn(x)
+        expected_y = mx.array(
+            [
+                [-0.439520, 1.647328, -0.955515, 1.966031],
+                [-1.726690, -1.449826, -0.234026, -0.723364],
+                [0.938414, -0.349603, -0.354470, -0.175369],
+                [0.305006, 0.234914, -0.393017, -0.459385],
+                [0.922789, -0.082813, 1.937028, -0.607913],
+            ],
+        )
+        expected_mean = mx.array([0.008929, 0.005680, -0.016092, 0.027778])
+        expected_var = mx.array([0.928435, 1.00455, 1.04117, 0.94258])
+        self.assertTrue(x.shape == y.shape)
+        self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
+        self.assertTrue(np.allclose(bn.running_mean, expected_mean, atol=1e-5))
+        self.assertTrue(np.allclose(bn.running_var, expected_var, atol=1e-5))
+
+        # test eval mode
+        bn.eval()
+        y = bn(x)
+        expected_y = mx.array(
+            [
+                [-0.15984, 1.73159, -1.25456, 1.57891],
+                [-0.872193, -1.4281, -0.414439, -0.228678],
+                [0.602743, -0.30566, -0.554687, 0.139639],
+                [0.252199, 0.29066, -0.599572, -0.0512532],
+                [0.594096, -0.0334829, 2.11359, -0.151081],
+            ]
+        )
+
+        self.assertTrue(x.shape == y.shape)
+        self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
+
+        # test_no_affine
+        bn = nn.BatchNorm1d(num_features=4, affine=False)
+        y = bn(x)
+        expected_y = mx.array(
+            [
+                [-0.439520, 1.647328, -0.955515, 1.966031],
+                [-1.726690, -1.449826, -0.234026, -0.723364],
+                [0.938414, -0.349603, -0.354470, -0.175369],
+                [0.305006, 0.234914, -0.393017, -0.459385],
+                [0.922789, -0.082813, 1.937028, -0.607913],
+            ]
+        )
+        self.assertTrue(x.shape == y.shape)
+        self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
+
     def test_conv1d(self):
         N = 5
         L = 12
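
A short usage sketch, not part of the patch above, illustrating the behavior the new test asserts: in training mode the layer normalizes with the current batch statistics and updates its running estimates, and after bn.eval() it reuses those estimates instead. The eps value (1e-5), the seed, and the fresh-initialization assumption (weight=1, bias=0, so the affine transform is the identity) are assumptions of this sketch, not shown in the diff.

import mlx.core as mx
import mlx.nn as nn

mx.random.seed(0)  # arbitrary seed, only for the sketch
x = mx.random.normal((5, 4))

bn = nn.BatchNorm1d(num_features=4)

# Training mode: normalize with the batch statistics and update
# running_mean / running_var.
y_train = bn(x)

# Reproduce the training-mode output by hand, mirroring the forward pass
# shown in the diff: x = (x - means) * mx.rsqrt(var + self.eps).
# eps=1e-5 is an assumed default; with a freshly initialized affine layer
# this should match bn(x).
mean = mx.mean(x, axis=0)
var = mx.var(x, axis=0)  # biased variance, as batch norm uses
y_manual = (x - mean) * mx.rsqrt(var + 1e-5)
print(mx.allclose(y_train, y_manual, atol=1e-5))  # expect True

# Eval mode: normalize with the accumulated running statistics, so the
# output differs from the training-mode result and the running stats
# stop changing on subsequent calls.
bn.eval()
y_eval = bn(x)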