implement-batch-norm-layer (#217)

- Add batch normalization layer
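
A minimal usage sketch, following what the new tests below exercise (`nn.BatchNorm`, the `affine` flag, `eval()`, and the running-stat buffers); constructor defaults such as momentum are assumptions inferred from the tests, not confirmed here:

```python
import mlx.core as mx
import mlx.nn as nn

x = mx.random.normal((5, 4))                    # (batch, num_features)
bn = nn.BatchNorm(num_features=4, affine=True)  # learnable scale/shift

y = bn(x)   # training mode: normalize with batch stats, update running buffers
bn.eval()
y = bn(x)   # eval mode: normalize with the accumulated running statistics
```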

---------

Co-authored-by: Robert McCraith <mccraithrobert@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
Authored by __mo_san__ on 2023-12-25 16:32:53 +01:00, committed by GitHub
commit a123c3c7d2 (parent 9e6b8c9f48)
6 changed files with 267 additions and 11 deletions


@@ -320,6 +320,143 @@ class TestNN(mlx_tests.MLXTestCase):
        self.assertTrue(np.allclose(means, 3 * np.ones_like(means), atol=1e-6))
        self.assertTrue(np.allclose(var, 4 * np.ones_like(var), atol=1e-6))

    def test_batch_norm(self):
        mx.random.seed(42)
        x = mx.random.normal((5, 4), dtype=mx.float32)

        # Batch norm
        bn = nn.BatchNorm(num_features=4, affine=True)
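        # Running statistics should start at zero mean and unit variance.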
        self.assertTrue(mx.allclose(bn._running_mean, mx.zeros_like(bn._running_mean)))
        self.assertTrue(mx.allclose(bn._running_var, mx.ones_like(bn._running_var)))

        y = bn(x)
        expected_y = mx.array(
            [
                [-0.439520, 1.647328, -0.955515, 1.966031],
                [-1.726690, -1.449826, -0.234026, -0.723364],
                [0.938414, -0.349603, -0.354470, -0.175369],
                [0.305006, 0.234914, -0.393017, -0.459385],
                [0.922789, -0.082813, 1.937028, -0.607913],
            ],
        )
        expected_mean = mx.array([0.008929, 0.005680, -0.016092, 0.027778])
        expected_var = mx.array([0.928435, 1.00455, 1.04117, 0.94258])
        self.assertTrue(x.shape == y.shape)
        self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))
        self.assertTrue(mx.allclose(bn._running_mean, expected_mean, atol=1e-5))
        self.assertTrue(mx.allclose(bn._running_var, expected_var, atol=1e-5))

        # test eval mode
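        # In eval mode the layer should normalize with the running statistics
        # accumulated during training instead of the per-batch statistics.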
        bn.eval()
        y = bn(x)
        expected_y = mx.array(
            [
                [-0.15984, 1.73159, -1.25456, 1.57891],
                [-0.872193, -1.4281, -0.414439, -0.228678],
                [0.602743, -0.30566, -0.554687, 0.139639],
                [0.252199, 0.29066, -0.599572, -0.0512532],
                [0.594096, -0.0334829, 2.11359, -0.151081],
            ]
        )
        self.assertTrue(x.shape == y.shape)
        self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))

        # test_no_affine
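        # With affine=False there is no learnable scale/shift; the expected
        # output matches the affine case above, whose weight and bias are
        # still at their initial identity values.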
        bn = nn.BatchNorm(num_features=4, affine=False)
        y = bn(x)
        expected_y = mx.array(
            [
                [-0.439520, 1.647328, -0.955515, 1.966031],
                [-1.726690, -1.449826, -0.234026, -0.723364],
                [0.938414, -0.349603, -0.354470, -0.175369],
                [0.305006, 0.234914, -0.393017, -0.459385],
                [0.922789, -0.082813, 1.937028, -0.607913],
            ]
        )
        self.assertTrue(x.shape == y.shape)
        self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))

        # test with 3D input
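        # For (N, L, C) inputs the feature axis is last; statistics are
        # computed over the remaining axes (see test_batch_norm_stats below).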
        mx.random.seed(42)
        N = 2
        L = 4
        C = 5
        x = mx.random.normal((N, L, C), dtype=mx.float32)

        # Batch norm
        bn = nn.BatchNorm(num_features=C, affine=True)
        self.assertTrue(mx.allclose(bn._running_mean, mx.zeros_like(bn._running_mean)))
        self.assertTrue(mx.allclose(bn._running_var, mx.ones_like(bn._running_var)))
        y = bn(x)
        self.assertTrue(x.shape == y.shape)
        expected_y = mx.array(
            [
                [
                    [-0.335754, 0.342054, 1.02653, 0.628588, -1.63899],
                    [1.92092, 0.432319, 0.343043, 1.95489, 1.0696],
                    [-0.853748, 1.3661, 0.868569, 0.0199196, -0.887284],
                    [0.459206, -0.684822, -0.706354, -0.271531, 0.566341],
                ],
                [
                    [-0.921179, 0.684951, -0.77466, -0.490372, -0.247032],
                    [1.10839, -2.13179, 0.628924, -1.62639, -0.539708],
                    [-0.348943, 0.412194, -2.03818, 0.524972, 1.64568],
                    [-1.02889, -0.421, 0.652127, -0.740079, 0.0313996],
                ],
            ]
        )
        self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))
        expected_mean = mx.array(
            [[[0.00207845, -5.3259e-05, 0.04755, -0.0697296, 0.0236228]]]
        )
        expected_var = mx.array([[[0.968415, 1.05322, 0.96913, 0.932305, 0.967224]]])
        self.assertTrue(mx.allclose(bn._running_mean, expected_mean, atol=1e-5))
        self.assertTrue(mx.allclose(bn._running_var, expected_var, atol=1e-5))

        x = mx.random.normal((N, L, C, L, C), dtype=mx.float32)
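        # Inputs with more than four dimensions are expected to be rejected.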
        with self.assertRaises(ValueError):
            y = bn(x)

    def test_batch_norm_stats(self):
        batch_size = 2
        num_features = 4
        h = 3
        w = 3
        momentum = 0.1
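        # Assumed to match the layer's default momentum for the running-stat
        # updates reproduced below.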

        batch_norm = nn.BatchNorm(num_features)
        batch_norm.train()
        running_mean = np.array(batch_norm._running_mean)
        running_var = np.array(batch_norm._running_var)

        data = mx.random.normal((batch_size, num_features))
        normalized_data = batch_norm(data)
        np_data = np.array(data)
        means = np.mean(np_data, axis=0)
        variances = np.var(np_data, axis=0)
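        # Expected exponential-moving-average update of the running buffers:
        #   running_stat <- (1 - momentum) * running_stat + momentum * batch_stat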
        running_mean = (1 - momentum) * running_mean + momentum * means
        running_var = (1 - momentum) * running_var + momentum * variances
        self.assertTrue(np.allclose(batch_norm._running_mean, running_mean, atol=1e-5))
        self.assertTrue(np.allclose(batch_norm._running_var, running_var, atol=1e-5))
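
        # Repeat with a 4D (N, H, W, C) input; statistics reduce over every
        # axis except the trailing feature axis.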
        batch_norm = nn.BatchNorm(num_features)
        batch_norm.train()
        running_mean = np.array(batch_norm._running_mean)
        running_var = np.array(batch_norm._running_var)
        data = mx.random.normal((batch_size, h, w, num_features))
        normalized_data = batch_norm(data)
        np_data = np.array(data)
        means = np.mean(np_data, axis=(0, 1, 2))
        variances = np.var(np_data, axis=(0, 1, 2))
        running_mean = (1 - momentum) * running_mean + momentum * means
        running_var = (1 - momentum) * running_var + momentum * variances
        self.assertTrue(np.allclose(batch_norm._running_mean, running_mean, atol=1e-5))
        self.assertTrue(np.allclose(batch_norm._running_var, running_var, atol=1e-5))

    def test_conv1d(self):
        N = 5
        L = 12