From e9fd1cf02d5913af6b4f4f0ffb291d8c5913ff21 Mon Sep 17 00:00:00 2001 From: m0saan Date: Tue, 19 Dec 2023 04:41:34 +0100 Subject: [PATCH] update batch norm implementation --- python/mlx/nn/layers/normalization.py | 32 ++++++++++++--------------- 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/python/mlx/nn/layers/normalization.py b/python/mlx/nn/layers/normalization.py index 769bab49b..6b1b53e06 100644 --- a/python/mlx/nn/layers/normalization.py +++ b/python/mlx/nn/layers/normalization.py @@ -182,6 +182,12 @@ class GroupNorm(Module): return (self.weight * x + self.bias) if "weight" in self else x +# Copyright © 2023 Apple Inc. + +import mlx.core as mx +from mlx.nn.layers.base import Module +from typing import Tuple + class BatchNorm1d(Module): r"""Applies Batch Normalization [1] to the inputs. @@ -205,14 +211,6 @@ class BatchNorm1d(Module): Examples: >>> import mlx.core as mx >>> import mlx.nn as nn - - >>> # With Learnable Parameters - >>> m = nn.BatchNorm1d(100) - >>> # Without Learnable Parameters - >>> m = nn.BatchNorm1d(4, affine=False) - >>> input = mx.random.normal(20, 4) - >>> output = m(input) - """ def __init__( @@ -229,9 +227,10 @@ class BatchNorm1d(Module): self.num_features = num_features self.eps = eps - self.momentum = momentum + self.momentum = mx.array([momentum]) self.running_mean = mx.zeros((num_features,)) self.running_var = mx.ones((num_features,)) + print(self.running_mean.shape) def _extra_repr(self): return f"num_features={self.num_features}, eps={self.eps}, momentum={self.momentum}, affine={'weight' in self}" @@ -248,10 +247,8 @@ class BatchNorm1d(Module): """ means = mx.mean(x, axis=0, keepdims=True) var = mx.var(x, axis=0, keepdims=True) - self.running_mean = ( - self.momentum * self.running_mean + (1 - self.momentum) * means - ) - self.running_var = self.momentum * self.running_var + (1 - self.momentum) * var + self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * means.squeeze() + self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var.squeeze() return means, var def __call__(self, x: mx.array): @@ -264,11 +261,10 @@ class BatchNorm1d(Module): Returns: mx.array: Output tensor. """ - if x.ndim != 2: - raise ValueError("BatchNorm1d only supports 2D inputs") - - means, var = self.running_mean, self.running_var + if self.training: means, var = self._calc_stats(x) + else: + means, var = self.running_mean, self.running_var x = (x - means) * mx.rsqrt(var + self.eps) - return (self.weight * x + self.bias) if "weight" in self else x + return (self.weight * x + self.bias) if "weight" in self else x \ No newline at end of file