Update normalization.py

This commit is contained in:
__mo_san__ 2023-12-19 09:17:50 +01:00 committed by m0saan
parent a0b2a34e98
commit eca773b62c

View File

@ -283,39 +283,3 @@ class BatchNorm1d(Module):
means, var = self.running_mean, self.running_var
x = (x - means) * mx.rsqrt(var + self.eps)
return (self.weight * x + self.bias) if "weight" in self else x
# import unittest
# import numpy as np
# class TestBatchNorm1d(unittest.TestCase):
# def setUp(self):
# self.bn = BatchNorm1d(10)
# def test_forward(self):
# x = mx.random.uniform(shape=(20, 10))
# y = self.bn(x)
# expedted =
# self.assertEqual(y.shape, x.shape)
# def test_running_stats(self):
# x = mx.random.uniform(shape=(20, 10))
# self.bn(x)
# self.assertNotEqual(mx.sum(self.bn.running_mean), 0)
# self.assertNotEqual(mx.sum(self.bn.running_var), 10)
# def test_eval_mode(self):
# x = mx.random.uniform(shape=(20, 10))
# self.bn(x)
# self.bn.training = False
# y = self.bn(x)
# self.assertEqual(y.shape, x.shape)
# def test_no_affine(self):
# bn = BatchNorm1d(10, affine=False)
# x = mx.random.uniform(shape=(20, 10))
# y = bn(x)
# self.assertEqual(y.shape, x.shape)
# if __name__ == '__main__':
# unittest.main()