diff --git a/python/mlx/nn/layers/normalization.py b/python/mlx/nn/layers/normalization.py
index 92847e423..4f7ad7b60 100644
--- a/python/mlx/nn/layers/normalization.py
+++ b/python/mlx/nn/layers/normalization.py
@@ -283,39 +283,3 @@ class BatchNorm1d(Module):
             means, var = self.running_mean, self.running_var
         x = (x - means) * mx.rsqrt(var + self.eps)
         return (self.weight * x + self.bias) if "weight" in self else x
-
-
-# import unittest
-# import numpy as np
-
-# class TestBatchNorm1d(unittest.TestCase):
-#     def setUp(self):
-#         self.bn = BatchNorm1d(10)
-
-#     def test_forward(self):
-#         x = mx.random.uniform(shape=(20, 10))
-#         y = self.bn(x)
-#         expedted =
-#         self.assertEqual(y.shape, x.shape)
-
-#     def test_running_stats(self):
-#         x = mx.random.uniform(shape=(20, 10))
-#         self.bn(x)
-#         self.assertNotEqual(mx.sum(self.bn.running_mean), 0)
-#         self.assertNotEqual(mx.sum(self.bn.running_var), 10)
-
-#     def test_eval_mode(self):
-#         x = mx.random.uniform(shape=(20, 10))
-#         self.bn(x)
-#         self.bn.training = False
-#         y = self.bn(x)
-#         self.assertEqual(y.shape, x.shape)
-
-#     def test_no_affine(self):
-#         bn = BatchNorm1d(10, affine=False)
-#         x = mx.random.uniform(shape=(20, 10))
-#         y = bn(x)
-#         self.assertEqual(y.shape, x.shape)
-
-# if __name__ == '__main__':
-#     unittest.main()