mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
rebasing ...
This commit is contained in:
parent
d4bf9a2976
commit
a0b2a34e98
@ -182,14 +182,6 @@ class GroupNorm(Module):
|
|||||||
return (self.weight * x + self.bias) if "weight" in self else x
|
return (self.weight * x + self.bias) if "weight" in self else x
|
||||||
|
|
||||||
|
|
||||||
# Copyright © 2023 Apple Inc.
|
|
||||||
|
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
import mlx.core as mx
|
|
||||||
from mlx.nn.layers.base import Module
|
|
||||||
|
|
||||||
|
|
||||||
class BatchNorm1d(Module):
|
class BatchNorm1d(Module):
|
||||||
r"""Applies Batch Normalization over a 2D or 3D input.
|
r"""Applies Batch Normalization over a 2D or 3D input.
|
||||||
|
|
||||||
@ -291,3 +283,39 @@ class BatchNorm1d(Module):
|
|||||||
means, var = self.running_mean, self.running_var
|
means, var = self.running_mean, self.running_var
|
||||||
x = (x - means) * mx.rsqrt(var + self.eps)
|
x = (x - means) * mx.rsqrt(var + self.eps)
|
||||||
return (self.weight * x + self.bias) if "weight" in self else x
|
return (self.weight * x + self.bias) if "weight" in self else x
|
||||||
|
|
||||||
|
|
||||||
|
# import unittest
|
||||||
|
# import numpy as np
|
||||||
|
|
||||||
|
# class TestBatchNorm1d(unittest.TestCase):
|
||||||
|
# def setUp(self):
|
||||||
|
# self.bn = BatchNorm1d(10)
|
||||||
|
|
||||||
|
# def test_forward(self):
|
||||||
|
# x = mx.random.uniform(shape=(20, 10))
|
||||||
|
# y = self.bn(x)
|
||||||
|
# expedted =
|
||||||
|
# self.assertEqual(y.shape, x.shape)
|
||||||
|
|
||||||
|
# def test_running_stats(self):
|
||||||
|
# x = mx.random.uniform(shape=(20, 10))
|
||||||
|
# self.bn(x)
|
||||||
|
# self.assertNotEqual(mx.sum(self.bn.running_mean), 0)
|
||||||
|
# self.assertNotEqual(mx.sum(self.bn.running_var), 10)
|
||||||
|
|
||||||
|
# def test_eval_mode(self):
|
||||||
|
# x = mx.random.uniform(shape=(20, 10))
|
||||||
|
# self.bn(x)
|
||||||
|
# self.bn.training = False
|
||||||
|
# y = self.bn(x)
|
||||||
|
# self.assertEqual(y.shape, x.shape)
|
||||||
|
|
||||||
|
# def test_no_affine(self):
|
||||||
|
# bn = BatchNorm1d(10, affine=False)
|
||||||
|
# x = mx.random.uniform(shape=(20, 10))
|
||||||
|
# y = bn(x)
|
||||||
|
# self.assertEqual(y.shape, x.shape)
|
||||||
|
|
||||||
|
# if __name__ == '__main__':
|
||||||
|
# unittest.main()
|
||||||
|
@ -320,6 +320,62 @@ class TestNN(mlx_tests.MLXTestCase):
|
|||||||
self.assertTrue(np.allclose(means, 3 * np.ones_like(means), atol=1e-6))
|
self.assertTrue(np.allclose(means, 3 * np.ones_like(means), atol=1e-6))
|
||||||
self.assertTrue(np.allclose(var, 4 * np.ones_like(var), atol=1e-6))
|
self.assertTrue(np.allclose(var, 4 * np.ones_like(var), atol=1e-6))
|
||||||
|
|
||||||
|
def test_batch_norm(self):
|
||||||
|
mx.random.seed(42)
|
||||||
|
x = mx.random.normal((5, 4), dtype=mx.float32)
|
||||||
|
|
||||||
|
# Batch norm
|
||||||
|
bn = nn.BatchNorm1d(num_features=4, affine=True)
|
||||||
|
self.assertTrue(mx.allclose(bn.running_mean, mx.zeros_like(bn.running_mean)))
|
||||||
|
self.assertTrue(mx.allclose(bn.running_var, mx.ones_like(bn.running_var)))
|
||||||
|
y = bn(x)
|
||||||
|
expected_y = mx.array(
|
||||||
|
[
|
||||||
|
[-0.439520, 1.647328, -0.955515, 1.966031],
|
||||||
|
[-1.726690, -1.449826, -0.234026, -0.723364],
|
||||||
|
[0.938414, -0.349603, -0.354470, -0.175369],
|
||||||
|
[0.305006, 0.234914, -0.393017, -0.459385],
|
||||||
|
[0.922789, -0.082813, 1.937028, -0.607913],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
expected_mean = mx.array([0.008929, 0.005680, -0.016092, 0.027778])
|
||||||
|
expected_var = mx.array([0.928435, 1.00455, 1.04117, 0.94258])
|
||||||
|
self.assertTrue(x.shape == y.shape)
|
||||||
|
self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
|
||||||
|
self.assertTrue(np.allclose(bn.running_mean, expected_mean, atol=1e-5))
|
||||||
|
self.assertTrue(np.allclose(bn.running_var, expected_var, atol=1e-5))
|
||||||
|
|
||||||
|
# test eval mode
|
||||||
|
bn.eval()
|
||||||
|
y = bn(x)
|
||||||
|
expected_y = mx.array(
|
||||||
|
[
|
||||||
|
[-0.15984, 1.73159, -1.25456, 1.57891],
|
||||||
|
[-0.872193, -1.4281, -0.414439, -0.228678],
|
||||||
|
[0.602743, -0.30566, -0.554687, 0.139639],
|
||||||
|
[0.252199, 0.29066, -0.599572, -0.0512532],
|
||||||
|
[0.594096, -0.0334829, 2.11359, -0.151081],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertTrue(x.shape == y.shape)
|
||||||
|
self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
|
||||||
|
|
||||||
|
# test_no_affine
|
||||||
|
bn = nn.BatchNorm1d(num_features=4, affine=False)
|
||||||
|
y = bn(x)
|
||||||
|
expected_y = mx.array(
|
||||||
|
[
|
||||||
|
[-0.439520, 1.647328, -0.955515, 1.966031],
|
||||||
|
[-1.726690, -1.449826, -0.234026, -0.723364],
|
||||||
|
[0.938414, -0.349603, -0.354470, -0.175369],
|
||||||
|
[0.305006, 0.234914, -0.393017, -0.459385],
|
||||||
|
[0.922789, -0.082813, 1.937028, -0.607913],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.assertTrue(x.shape == y.shape)
|
||||||
|
self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
|
||||||
|
|
||||||
def test_conv1d(self):
|
def test_conv1d(self):
|
||||||
N = 5
|
N = 5
|
||||||
L = 12
|
L = 12
|
||||||
|
Loading…
Reference in New Issue
Block a user