update batch norm implementation

m0saan 2023-12-19 04:41:34 +01:00
parent ad53687ae7
commit e9fd1cf02d


@@ -182,6 +182,12 @@ class GroupNorm(Module):
         return (self.weight * x + self.bias) if "weight" in self else x
+
+# Copyright © 2023 Apple Inc.
+
+import mlx.core as mx
+from mlx.nn.layers.base import Module
+from typing import Tuple

 class BatchNorm1d(Module):
     r"""Applies Batch Normalization [1] to the inputs.
@@ -205,14 +211,6 @@ class BatchNorm1d(Module):
     Examples:
         >>> import mlx.core as mx
        >>> import mlx.nn as nn
-        >>> # With Learnable Parameters
-        >>> m = nn.BatchNorm1d(100)
-        >>> # Without Learnable Parameters
-        >>> m = nn.BatchNorm1d(4, affine=False)
-        >>> input = mx.random.normal(20, 4)
-        >>> output = m(input)
    """

    def __init__(
@@ -229,9 +227,10 @@ class BatchNorm1d(Module):
         self.num_features = num_features
         self.eps = eps
-        self.momentum = momentum
+        self.momentum = mx.array([momentum])
         self.running_mean = mx.zeros((num_features,))
         self.running_var = mx.ones((num_features,))
+        print(self.running_mean.shape)

     def _extra_repr(self):
         return f"num_features={self.num_features}, eps={self.eps}, momentum={self.momentum}, affine={'weight' in self}"
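Storing momentum as a one-element mx.array instead of a plain float does not change the arithmetic of the update below, since a shape-(1,) array broadcasts against the shape-(num_features,) running buffers exactly like a scalar does. A standalone sketch (not part of this commit) checking that equivalence:

    import mlx.core as mx

    running = mx.zeros((4,))
    batch_mean = mx.ones((4,))

    # A float momentum and a one-element array momentum broadcast identically.
    out_float = (1 - 0.1) * running + 0.1 * batch_mean
    out_array = (1 - mx.array([0.1])) * running + mx.array([0.1]) * batch_mean
    print(out_float)  # array([0.1, 0.1, 0.1, 0.1], dtype=float32)
    print(out_array)  # same values; (1,) x (4,) broadcasts to shape (4,)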
@@ -248,10 +247,8 @@ class BatchNorm1d(Module):
         """
         means = mx.mean(x, axis=0, keepdims=True)
         var = mx.var(x, axis=0, keepdims=True)
-        self.running_mean = (
-            self.momentum * self.running_mean + (1 - self.momentum) * means
-        )
-        self.running_var = self.momentum * self.running_var + (1 - self.momentum) * var
+        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * means.squeeze()
+        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var.squeeze()
         return means, var

     def __call__(self, x: mx.array):
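This hunk flips the exponential-moving-average convention: momentum now weights the statistics of the current batch rather than the accumulated running value (the PyTorch-style rule running = (1 - momentum) * running + momentum * batch_stat), and .squeeze() drops the keepdims axis so the running buffers stay shape (num_features,). A worked sketch of one update step under that convention, with illustrative values:

    import mlx.core as mx

    momentum = 0.1
    running_mean = mx.zeros((3,))
    batch_mean = mx.array([1.0, 2.0, 3.0])

    # running <- (1 - momentum) * running + momentum * current batch statistic
    running_mean = (1 - momentum) * running_mean + momentum * batch_mean
    print(running_mean)  # array([0.1, 0.2, 0.3], dtype=float32)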
@@ -264,11 +261,10 @@ class BatchNorm1d(Module):
         Returns:
             mx.array: Output tensor.
         """
-        if x.ndim != 2:
-            raise ValueError("BatchNorm1d only supports 2D inputs")
-        means, var = self.running_mean, self.running_var
         if self.training:
             means, var = self._calc_stats(x)
+        else:
+            means, var = self.running_mean, self.running_var
         x = (x - means) * mx.rsqrt(var + self.eps)
         return (self.weight * x + self.bias) if "weight" in self else x
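A minimal usage sketch of the updated layer, assuming it is exposed as mlx.nn.BatchNorm1d (as in the docstring example removed above) and that Module.train() / Module.eval() toggle self.training:

    import mlx.core as mx
    import mlx.nn as nn

    bn = nn.BatchNorm1d(4)
    x = mx.random.normal((20, 4))

    bn.train()   # training mode: normalize with batch stats, update running buffers
    y_train = bn(x)

    bn.eval()    # inference mode: normalize with stored running_mean / running_var
    y_eval = bn(x)
    print(y_train.shape, y_eval.shape)  # (20, 4) (20, 4)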