mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 19:38:16 +08:00
Fix BN stats to not expand shape (#409)
* fix BN stats to not expand shape * nit
This commit is contained in:
@@ -333,8 +333,8 @@ class BatchNorm(Module):
|
||||
"""
|
||||
reduction_axes = tuple(range(0, x.ndim - 1))
|
||||
|
||||
mean = mx.mean(x, axis=reduction_axes, keepdims=True)
|
||||
var = mx.var(x, axis=reduction_axes, keepdims=True)
|
||||
mean = mx.mean(x, axis=reduction_axes)
|
||||
var = mx.var(x, axis=reduction_axes)
|
||||
|
||||
return mean, var
|
||||
|
||||
|
Reference in New Issue
Block a user