mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
Update normalization.py
This commit is contained in:
parent
a0b2a34e98
commit
eca773b62c
@ -283,39 +283,3 @@ class BatchNorm1d(Module):
|
||||
means, var = self.running_mean, self.running_var
|
||||
x = (x - means) * mx.rsqrt(var + self.eps)
|
||||
return (self.weight * x + self.bias) if "weight" in self else x
|
||||
|
||||
|
||||
# import unittest
|
||||
# import numpy as np
|
||||
|
||||
# class TestBatchNorm1d(unittest.TestCase):
|
||||
# def setUp(self):
|
||||
# self.bn = BatchNorm1d(10)
|
||||
|
||||
# def test_forward(self):
|
||||
# x = mx.random.uniform(shape=(20, 10))
|
||||
# y = self.bn(x)
|
||||
# expedted =
|
||||
# self.assertEqual(y.shape, x.shape)
|
||||
|
||||
# def test_running_stats(self):
|
||||
# x = mx.random.uniform(shape=(20, 10))
|
||||
# self.bn(x)
|
||||
# self.assertNotEqual(mx.sum(self.bn.running_mean), 0)
|
||||
# self.assertNotEqual(mx.sum(self.bn.running_var), 10)
|
||||
|
||||
# def test_eval_mode(self):
|
||||
# x = mx.random.uniform(shape=(20, 10))
|
||||
# self.bn(x)
|
||||
# self.bn.training = False
|
||||
# y = self.bn(x)
|
||||
# self.assertEqual(y.shape, x.shape)
|
||||
|
||||
# def test_no_affine(self):
|
||||
# bn = BatchNorm1d(10, affine=False)
|
||||
# x = mx.random.uniform(shape=(20, 10))
|
||||
# y = bn(x)
|
||||
# self.assertEqual(y.shape, x.shape)
|
||||
|
||||
# if __name__ == '__main__':
|
||||
# unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user