rebasing ...

m0saan 2023-12-19 09:16:53 +01:00
parent d4bf9a2976
commit a0b2a34e98
2 changed files with 92 additions and 8 deletions


@@ -182,14 +182,6 @@ class GroupNorm(Module):
        return (self.weight * x + self.bias) if "weight" in self else x
# Copyright © 2023 Apple Inc.
from typing import Tuple
import mlx.core as mx
from mlx.nn.layers.base import Module
class BatchNorm1d(Module):
    r"""Applies Batch Normalization over a 2D or 3D input.
@@ -291,3 +283,39 @@ class BatchNorm1d(Module):
            means, var = self.running_mean, self.running_var
        x = (x - means) * mx.rsqrt(var + self.eps)
        return (self.weight * x + self.bias) if "weight" in self else x
# import unittest
# import numpy as np
#
# class TestBatchNorm1d(unittest.TestCase):
#     def setUp(self):
#         self.bn = BatchNorm1d(10)
#
#     def test_forward(self):
#         x = mx.random.uniform(shape=(20, 10))
#         y = self.bn(x)
#         self.assertEqual(y.shape, x.shape)
#
#     def test_running_stats(self):
#         x = mx.random.uniform(shape=(20, 10))
#         self.bn(x)
#         self.assertNotEqual(mx.sum(self.bn.running_mean), 0)
#         self.assertNotEqual(mx.sum(self.bn.running_var), 10)
#
#     def test_eval_mode(self):
#         x = mx.random.uniform(shape=(20, 10))
#         self.bn(x)
#         self.bn.eval()
#         y = self.bn(x)
#         self.assertEqual(y.shape, x.shape)
#
#     def test_no_affine(self):
#         bn = BatchNorm1d(10, affine=False)
#         x = mx.random.uniform(shape=(20, 10))
#         y = bn(x)
#         self.assertEqual(y.shape, x.shape)
#
# if __name__ == '__main__':
#     unittest.main()
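
For readers skimming the diff, a minimal usage sketch of the layer this file adds, assuming only the API visible in the hunks above (the running_mean / running_var buffers and the eval() switch inherited from Module):

import mlx.core as mx
import mlx.nn as nn

# A fresh layer starts with running_mean = 0 and running_var = 1
# (per the asserts in the test file below).
bn = nn.BatchNorm1d(num_features=4)
x = mx.random.normal((5, 4))

# Training mode: normalizes with batch statistics and updates the
# running buffers as a side effect.
y_train = bn(x)

# Eval mode: reuses the accumulated running statistics instead, so the
# same input now maps to different (deterministic) outputs.
bn.eval()
y_eval = bn(x)

print(y_train.shape, y_eval.shape)  # both match x.shape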


@@ -320,6 +320,62 @@ class TestNN(mlx_tests.MLXTestCase):
        self.assertTrue(np.allclose(means, 3 * np.ones_like(means), atol=1e-6))
        self.assertTrue(np.allclose(var, 4 * np.ones_like(var), atol=1e-6))
    def test_batch_norm(self):
        mx.random.seed(42)
        x = mx.random.normal((5, 4), dtype=mx.float32)

        # Batch norm
        bn = nn.BatchNorm1d(num_features=4, affine=True)
        self.assertTrue(mx.allclose(bn.running_mean, mx.zeros_like(bn.running_mean)))
        self.assertTrue(mx.allclose(bn.running_var, mx.ones_like(bn.running_var)))
        y = bn(x)
        expected_y = mx.array(
            [
                [-0.439520, 1.647328, -0.955515, 1.966031],
                [-1.726690, -1.449826, -0.234026, -0.723364],
                [0.938414, -0.349603, -0.354470, -0.175369],
                [0.305006, 0.234914, -0.393017, -0.459385],
                [0.922789, -0.082813, 1.937028, -0.607913],
            ],
        )
        expected_mean = mx.array([0.008929, 0.005680, -0.016092, 0.027778])
        expected_var = mx.array([0.928435, 1.00455, 1.04117, 0.94258])
        self.assertTrue(x.shape == y.shape)
        self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
        self.assertTrue(np.allclose(bn.running_mean, expected_mean, atol=1e-5))
        self.assertTrue(np.allclose(bn.running_var, expected_var, atol=1e-5))

        # test eval mode
        bn.eval()
        y = bn(x)
        expected_y = mx.array(
            [
                [-0.15984, 1.73159, -1.25456, 1.57891],
                [-0.872193, -1.4281, -0.414439, -0.228678],
                [0.602743, -0.30566, -0.554687, 0.139639],
                [0.252199, 0.29066, -0.599572, -0.0512532],
                [0.594096, -0.0334829, 2.11359, -0.151081],
            ]
        )
        self.assertTrue(x.shape == y.shape)
        self.assertTrue(np.allclose(y, expected_y, atol=1e-5))

        # test_no_affine
        bn = nn.BatchNorm1d(num_features=4, affine=False)
        y = bn(x)
        expected_y = mx.array(
            [
                [-0.439520, 1.647328, -0.955515, 1.966031],
                [-1.726690, -1.449826, -0.234026, -0.723364],
                [0.938414, -0.349603, -0.354470, -0.175369],
                [0.305006, 0.234914, -0.393017, -0.459385],
                [0.922789, -0.082813, 1.937028, -0.607913],
            ]
        )
        self.assertTrue(x.shape == y.shape)
        self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
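
The expected_mean / expected_var values above follow from the conventional running-stat blend; here is a hedged NumPy sketch of that rule, assuming the layer's default momentum is 0.1 (the diff does not show the default explicitly, and this reproduces the rule, not the exact seeded values in the test):

import numpy as np

momentum = 0.1  # assumed default; not shown in this diff
eps = 1e-5

x = np.random.normal(size=(5, 4)).astype(np.float32)
running_mean = np.zeros(4, dtype=np.float32)
running_var = np.ones(4, dtype=np.float32)

# Training step: normalize with batch stats, then blend them into the buffers.
batch_mean = x.mean(axis=0)
batch_var = x.var(axis=0)
y = (x - batch_mean) / np.sqrt(batch_var + eps)
running_mean = (1 - momentum) * running_mean + momentum * batch_mean
running_var = (1 - momentum) * running_var + momentum * batch_var

# Eval step: normalize with the running buffers instead.
y_eval = (x - running_mean) / np.sqrt(running_var + eps)
print(y.shape, y_eval.shape)  # (5, 4) (5, 4)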
    def test_conv1d(self):
        N = 5
        L = 12