mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 12:06:42 +08:00
calc running mean and var only when training
This commit is contained in:
parent
c3c2fcf41d
commit
d4bf9a2976
@ -258,7 +258,7 @@ class BatchNorm1d(Module):
|
||||
means = mx.mean(x, axis=(0, 2), keepdims=True)
|
||||
var = mx.var(x, axis=(0, 2), keepdims=True)
|
||||
|
||||
if self.track_running_stats:
|
||||
if self.track_running_stats and self.training:
|
||||
self.running_mean = (
|
||||
1 - self.momentum
|
||||
) * self.running_mean + self.momentum * means.squeeze()
|
||||
|
Loading…
Reference in New Issue
Block a user