diff --git a/python/mlx/nn/layers/normalization.py b/python/mlx/nn/layers/normalization.py index b3c1ceefb..e05633347 100644 --- a/python/mlx/nn/layers/normalization.py +++ b/python/mlx/nn/layers/normalization.py @@ -253,9 +253,9 @@ class BatchNorm(Module): num_dims = len(x.shape) dims_dict = { - 2: ((1, self.num_features), (0,)), - 3: ((1, self.num_features, 1), (0, 2)), - 4: ((1, self.num_features, 1, 1), (0, 2, 3)), + 2: ((1, self.num_features), (0,)), + 3: ((1, 1, self.num_features), (0, 1)), + 4: ((1, 1, 1, self.num_features), (0, 1, 2)), } if num_dims not in dims_dict: