From c3c2fcf41ddafe961d9847c5bad9e33dadc46b56 Mon Sep 17 00:00:00 2001
From: m0saan
Date: Tue, 19 Dec 2023 06:27:14 +0100
Subject: [PATCH] update batch norm implementation: fix bugs and add support
 for 3D inputs

---
 python/mlx/nn/layers/normalization.py | 65 ++++++++++++++++++---------
 1 file changed, 44 insertions(+), 21 deletions(-)

diff --git a/python/mlx/nn/layers/normalization.py b/python/mlx/nn/layers/normalization.py
index 6b1b53e06..b29d87d30 100644
--- a/python/mlx/nn/layers/normalization.py
+++ b/python/mlx/nn/layers/normalization.py
@@ -184,12 +184,14 @@ class GroupNorm(Module):
 # Copyright © 2023 Apple Inc.
 
-import mlx.core as mx
-from mlx.nn.layers.base import Module
 from typing import Tuple
 
+import mlx.core as mx
+from mlx.nn.layers.base import Module
+
+
 class BatchNorm1d(Module):
-    r"""Applies Batch Normalization [1] to the inputs.
+    r"""Applies Batch Normalization over a 2D or 3D input.
 
     Computes
@@ -209,8 +211,7 @@ class BatchNorm1d(Module):
         affine (bool, optional): If True, learn an affine transform to apply after
             the normalization. Default is True.
     Examples:
-        >>> import mlx.core as mx
-        >>> import mlx.nn as nn
+        -> TODO: Add examples.
     """
 
     def __init__(
@@ -219,21 +220,25 @@ class BatchNorm1d(Module):
         eps: float = 1e-5,
         momentum: float = 0.1,
         affine: bool = True,
+        track_running_stats: bool = True,
     ):
         super().__init__()
-        if affine:
-            self.bias = mx.zeros((num_features,))
-            self.weight = mx.ones((num_features,))
-
         self.num_features = num_features
         self.eps = eps
-        self.momentum = mx.array([momentum])
-        self.running_mean = mx.zeros((num_features,))
-        self.running_var = mx.ones((num_features,))
-        print(self.running_mean.shape)
+        self.momentum = momentum
+        self.affine = affine
+        self.track_running_stats = track_running_stats
+
+        if self.affine:
+            self.weight = mx.ones((num_features,))
+            self.bias = mx.zeros((num_features,))
+
+        if self.track_running_stats:
+            self.running_mean = mx.zeros((num_features,))
+            self.running_var = mx.ones((num_features,))
 
     def _extra_repr(self):
-        return f"num_features={self.num_features}, eps={self.eps}, momentum={self.momentum}, affine={'weight' in self}"
+        return f"{self.num_features}, eps={self.eps}, momentum={self.momentum}, affine={'weight' in self}, track_running_stats={self.track_running_stats}"
 
     def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]:
         """
         Calculate the mean and variance of the input tensor.
 
         Args:
             x (mx.array): Input tensor.
 
@@ -245,10 +250,21 @@ class BatchNorm1d(Module):
         Returns:
             tuple: Tuple containing mean and variance.
         """
-        means = mx.mean(x, axis=0, keepdims=True)
-        var = mx.var(x, axis=0, keepdims=True)
-        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * means.squeeze()
-        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var.squeeze()
+        if x.ndim == 2:
+            means = mx.mean(x, axis=0, keepdims=True)
+            var = mx.var(x, axis=0, keepdims=True)
+        else:
+            means = mx.mean(x, axis=(0, 2), keepdims=True)
+            var = mx.var(x, axis=(0, 2), keepdims=True)
+
+        if self.track_running_stats:
+            self.running_mean = (
+                1 - self.momentum
+            ) * self.running_mean + self.momentum * means.squeeze()
+            self.running_var = (
+                1 - self.momentum
+            ) * self.running_var + self.momentum * var.squeeze()
         return means, var
 
     def __call__(self, x: mx.array):
         """
         Forward pass of BatchNorm1d.
 
@@ -261,10 +277,17 @@ class BatchNorm1d(Module):
         Args:
             x (mx.array): Input tensor.
 
         Returns:
            mx.array: Output tensor.
""" - - if self.training: + + if x.ndim != 2 and x.ndim != 3: + raise ValueError(f"expected 2D or 3D input (got {x.ndim}D input)") + + if x.ndim == 3: + self.weight = mx.expand_dims(self.weight, [0, 2]) + self.bias = mx.expand_dims(self.bias, [0, 2]) + + if self.training or not self.track_running_stats: means, var = self._calc_stats(x) else: means, var = self.running_mean, self.running_var x = (x - means) * mx.rsqrt(var + self.eps) - return (self.weight * x + self.bias) if "weight" in self else x \ No newline at end of file + return (self.weight * x + self.bias) if "weight" in self else x