diff --git a/python/mlx/nn/layers/normalization.py b/python/mlx/nn/layers/normalization.py index b29d87d30..72058aae3 100644 --- a/python/mlx/nn/layers/normalization.py +++ b/python/mlx/nn/layers/normalization.py @@ -258,7 +258,7 @@ class BatchNorm1d(Module): means = mx.mean(x, axis=(0, 2), keepdims=True) var = mx.var(x, axis=(0, 2), keepdims=True) - if self.track_running_stats: + if self.track_running_stats and self.training: self.running_mean = ( 1 - self.momentum ) * self.running_mean + self.momentum * means.squeeze()