Fix BN stats to not expand shape (#409)

* fix BN stats to not expand shape

* nit
This commit is contained in:
Awni Hannun
2024-01-09 11:54:51 -08:00
committed by GitHub
parent 753867123d
commit e9ca65c939
2 changed files with 17 additions and 16 deletions

View File

@@ -333,8 +333,8 @@ class BatchNorm(Module):
"""
reduction_axes = tuple(range(0, x.ndim - 1))
mean = mx.mean(x, axis=reduction_axes, keepdims=True)
var = mx.var(x, axis=reduction_axes, keepdims=True)
mean = mx.mean(x, axis=reduction_axes)
var = mx.var(x, axis=reduction_axes)
return mean, var