mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
Update python/mlx/nn/layers/normalization.py
Co-authored-by: Robert McCraith <mccraithrobert@gmail.com>
This commit is contained in:
parent
019a85511c
commit
c68a472b83
@ -234,8 +234,8 @@ class BatchNorm(Module):
|
|||||||
self.bias = mx.zeros((num_features,))
|
self.bias = mx.zeros((num_features,))
|
||||||
|
|
||||||
if self.track_running_stats:
|
if self.track_running_stats:
|
||||||
self.running_mean = mx.zeros((num_features,))
|
self._running_mean = mx.zeros((num_features,))
|
||||||
self.running_var = mx.ones((num_features,))
|
self._running_var = mx.ones((num_features,))
|
||||||
|
|
||||||
def _extra_repr(self):
|
def _extra_repr(self):
|
||||||
return f"{self.num_features}, eps={self.eps}, momentum={self.momentum}, affine={'weight' in self}, track_running_stats={self.track_running_stats}"
|
return f"{self.num_features}, eps={self.eps}, momentum={self.momentum}, affine={'weight' in self}, track_running_stats={self.track_running_stats}"
|
||||||
|
Loading…
Reference in New Issue
Block a user