Update batch norm implementation: fix bugs and add support for 3D inputs

This commit is contained in:
m0saan 2023-12-19 06:27:14 +01:00
parent e9fd1cf02d
commit c3c2fcf41d

View File

@ -184,12 +184,14 @@ class GroupNorm(Module):
# Copyright © 2023 Apple Inc. # Copyright © 2023 Apple Inc.
import mlx.core as mx
from mlx.nn.layers.base import Module
from typing import Tuple from typing import Tuple
import mlx.core as mx
from mlx.nn.layers.base import Module
class BatchNorm1d(Module): class BatchNorm1d(Module):
r"""Applies Batch Normalization [1] to the inputs. r"""Applies Batch Normalization over a 2D or 3D input.
Computes Computes
@ -209,8 +211,7 @@ class BatchNorm1d(Module):
affine (bool, optional): If True, learn an affine transform to apply after the normalization. Default is True. affine (bool, optional): If True, learn an affine transform to apply after the normalization. Default is True.
Examples: Examples:
>>> import mlx.core as mx -> TODO: Add examples.
>>> import mlx.nn as nn
""" """
def __init__( def __init__(
@ -219,21 +220,25 @@ class BatchNorm1d(Module):
eps: float = 1e-5, eps: float = 1e-5,
momentum: float = 0.1, momentum: float = 0.1,
affine: bool = True, affine: bool = True,
track_running_stats: bool = True,
): ):
super().__init__() super().__init__()
if affine:
self.bias = mx.zeros((num_features,))
self.weight = mx.ones((num_features,))
self.num_features = num_features self.num_features = num_features
self.eps = eps self.eps = eps
self.momentum = mx.array([momentum]) self.momentum = momentum
self.running_mean = mx.zeros((num_features,)) self.affine = affine
self.running_var = mx.ones((num_features,)) self.track_running_stats = track_running_stats
print(self.running_mean.shape)
if self.affine:
self.weight = mx.ones((num_features,))
self.bias = mx.zeros((num_features,))
if self.track_running_stats:
self.running_mean = mx.zeros((num_features,))
self.running_var = mx.ones((num_features,))
def _extra_repr(self): def _extra_repr(self):
return f"num_features={self.num_features}, eps={self.eps}, momentum={self.momentum}, affine={'weight' in self}" return f"{self.num_features}, eps={self.eps}, momentum={self.momentum}, affine={'weight' in self}, track_running_stats={self.track_running_stats}"
def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]: def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]:
""" """
@ -245,10 +250,21 @@ class BatchNorm1d(Module):
Returns: Returns:
tuple: Tuple containing mean and variance. tuple: Tuple containing mean and variance.
""" """
means = mx.mean(x, axis=0, keepdims=True)
var = mx.var(x, axis=0, keepdims=True) if len(x.shape) == 2:
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * means.squeeze() means = mx.mean(x, axis=0, keepdims=True)
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var.squeeze() var = mx.var(x, axis=0, keepdims=True)
else:
means = mx.mean(x, axis=(0, 2), keepdims=True)
var = mx.var(x, axis=(0, 2), keepdims=True)
if self.track_running_stats:
self.running_mean = (
1 - self.momentum
) * self.running_mean + self.momentum * means.squeeze()
self.running_var = (
1 - self.momentum
) * self.running_var + self.momentum * var.squeeze()
return means, var return means, var
def __call__(self, x: mx.array): def __call__(self, x: mx.array):
@ -262,7 +278,14 @@ class BatchNorm1d(Module):
mx.array: Output tensor. mx.array: Output tensor.
""" """
if self.training: if x.ndim != 2 and x.ndim != 3:
raise ValueError(f"expected 2D or 3D input (got {x.ndim}D input)")
if x.ndim == 3:
self.weight = mx.expand_dims(self.weight, [0, 2])
self.bias = mx.expand_dims(self.bias, [0, 2])
if self.training or not self.track_running_stats:
means, var = self._calc_stats(x) means, var = self._calc_stats(x)
else: else:
means, var = self.running_mean, self.running_var means, var = self.running_mean, self.running_var