From d4bf9a29762ffa2258bd6add33c71124709b2a43 Mon Sep 17 00:00:00 2001 From: m0saan Date: Tue, 19 Dec 2023 06:29:52 +0100 Subject: [PATCH] calc running mean and var only when training --- python/mlx/nn/layers/normalization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mlx/nn/layers/normalization.py b/python/mlx/nn/layers/normalization.py index b29d87d30..72058aae3 100644 --- a/python/mlx/nn/layers/normalization.py +++ b/python/mlx/nn/layers/normalization.py @@ -258,7 +258,7 @@ class BatchNorm1d(Module): means = mx.mean(x, axis=(0, 2), keepdims=True) var = mx.var(x, axis=(0, 2), keepdims=True) - if self.track_running_stats: + if self.track_running_stats and self.training: self.running_mean = ( 1 - self.momentum ) * self.running_mean + self.momentum * means.squeeze()