calc running mean and var only when training

This commit is contained in:
m0saan 2023-12-19 06:29:52 +01:00
parent c3c2fcf41d
commit d4bf9a2976

View File

@ -258,7 +258,7 @@ class BatchNorm1d(Module):
means = mx.mean(x, axis=(0, 2), keepdims=True)
var = mx.var(x, axis=(0, 2), keepdims=True)
if self.track_running_stats:
if self.track_running_stats and self.training:
self.running_mean = (
1 - self.momentum
) * self.running_mean + self.momentum * means.squeeze()