mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
calc running mean and var only when training
This commit is contained in:
parent
c3c2fcf41d
commit
d4bf9a2976
@ -258,7 +258,7 @@ class BatchNorm1d(Module):
|
|||||||
means = mx.mean(x, axis=(0, 2), keepdims=True)
|
means = mx.mean(x, axis=(0, 2), keepdims=True)
|
||||||
var = mx.var(x, axis=(0, 2), keepdims=True)
|
var = mx.var(x, axis=(0, 2), keepdims=True)
|
||||||
|
|
||||||
if self.track_running_stats:
|
if self.track_running_stats and self.training:
|
||||||
self.running_mean = (
|
self.running_mean = (
|
||||||
1 - self.momentum
|
1 - self.momentum
|
||||||
) * self.running_mean + self.momentum * means.squeeze()
|
) * self.running_mean + self.momentum * means.squeeze()
|
||||||
|
Loading…
Reference in New Issue
Block a user