mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-22 21:16:47 +08:00
Update BatchNorm1d: store momentum as an array, swap the running-statistics update weights, and drop the 2D-input check
This commit is contained in:
parent
ad53687ae7
commit
e9fd1cf02d
@ -182,6 +182,12 @@ class GroupNorm(Module):
|
|||||||
return (self.weight * x + self.bias) if "weight" in self else x
|
return (self.weight * x + self.bias) if "weight" in self else x
|
||||||
|
|
||||||
|
|
||||||
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from mlx.nn.layers.base import Module
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
class BatchNorm1d(Module):
|
class BatchNorm1d(Module):
|
||||||
r"""Applies Batch Normalization [1] to the inputs.
|
r"""Applies Batch Normalization [1] to the inputs.
|
||||||
|
|
||||||
@ -205,14 +211,6 @@ class BatchNorm1d(Module):
|
|||||||
Examples:
|
Examples:
|
||||||
>>> import mlx.core as mx
|
>>> import mlx.core as mx
|
||||||
>>> import mlx.nn as nn
|
>>> import mlx.nn as nn
|
||||||
|
|
||||||
>>> # With Learnable Parameters
|
|
||||||
>>> m = nn.BatchNorm1d(100)
|
|
||||||
>>> # Without Learnable Parameters
|
|
||||||
>>> m = nn.BatchNorm1d(4, affine=False)
|
|
||||||
>>> input = mx.random.normal((20, 4))
|
|
||||||
>>> output = m(input)
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -229,9 +227,10 @@ class BatchNorm1d(Module):
|
|||||||
|
|
||||||
self.num_features = num_features
|
self.num_features = num_features
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
self.momentum = momentum
|
self.momentum = mx.array([momentum])
|
||||||
self.running_mean = mx.zeros((num_features,))
|
self.running_mean = mx.zeros((num_features,))
|
||||||
self.running_var = mx.ones((num_features,))
|
self.running_var = mx.ones((num_features,))
|
||||||
|
print(self.running_mean.shape)
|
||||||
|
|
||||||
def _extra_repr(self):
|
def _extra_repr(self):
|
||||||
return f"num_features={self.num_features}, eps={self.eps}, momentum={self.momentum}, affine={'weight' in self}"
|
return f"num_features={self.num_features}, eps={self.eps}, momentum={self.momentum}, affine={'weight' in self}"
|
||||||
@ -248,10 +247,8 @@ class BatchNorm1d(Module):
|
|||||||
"""
|
"""
|
||||||
means = mx.mean(x, axis=0, keepdims=True)
|
means = mx.mean(x, axis=0, keepdims=True)
|
||||||
var = mx.var(x, axis=0, keepdims=True)
|
var = mx.var(x, axis=0, keepdims=True)
|
||||||
self.running_mean = (
|
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * means.squeeze()
|
||||||
self.momentum * self.running_mean + (1 - self.momentum) * means
|
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var.squeeze()
|
||||||
)
|
|
||||||
self.running_var = self.momentum * self.running_var + (1 - self.momentum) * var
|
|
||||||
return means, var
|
return means, var
|
||||||
|
|
||||||
def __call__(self, x: mx.array):
|
def __call__(self, x: mx.array):
|
||||||
@ -264,11 +261,10 @@ class BatchNorm1d(Module):
|
|||||||
Returns:
|
Returns:
|
||||||
mx.array: Output tensor.
|
mx.array: Output tensor.
|
||||||
"""
|
"""
|
||||||
if x.ndim != 2:
|
|
||||||
raise ValueError("BatchNorm1d only supports 2D inputs")
|
|
||||||
|
|
||||||
means, var = self.running_mean, self.running_var
|
|
||||||
if self.training:
|
if self.training:
|
||||||
means, var = self._calc_stats(x)
|
means, var = self._calc_stats(x)
|
||||||
|
else:
|
||||||
|
means, var = self.running_mean, self.running_var
|
||||||
x = (x - means) * mx.rsqrt(var + self.eps)
|
x = (x - means) * mx.rsqrt(var + self.eps)
|
||||||
return (self.weight * x + self.bias) if "weight" in self else x
|
return (self.weight * x + self.bias) if "weight" in self else x
|
Loading…
Reference in New Issue
Block a user