From 28009c9cdb8b31694724641230003ccf4468b2bb Mon Sep 17 00:00:00 2001 From: __mo_san__ <50895527+m0saan@users.noreply.github.com> Date: Sun, 24 Dec 2023 22:22:26 +0100 Subject: [PATCH] Update python/mlx/nn/layers/normalization.py Update BatchNorm to support NLC and NHWC input formats In our convolution operations, we follow the convention that the channels are the last dimension. This commit updates the BatchNorm implementation to support inputs where the channels are the last dimension (NLC or NHWC). This involves changing the dimensions we normalize over and the dimensions we expand our parameters over. Co-authored-by: Robert McCraith --- python/mlx/nn/layers/normalization.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/mlx/nn/layers/normalization.py b/python/mlx/nn/layers/normalization.py index b3c1ceefb..e05633347 100644 --- a/python/mlx/nn/layers/normalization.py +++ b/python/mlx/nn/layers/normalization.py @@ -253,9 +253,9 @@ class BatchNorm(Module): num_dims = len(x.shape) dims_dict = { - 2: ((1, self.num_features), (0,)), - 3: ((1, self.num_features, 1), (0, 2)), - 4: ((1, self.num_features, 1, 1), (0, 2, 3)), + 2: ((1, self.num_features), (0,)), + 3: ((1, 1, self.num_features), (0, 1)), + 4: ((1, 1, 1, self.num_features), (0, 1, 2)), } if num_dims not in dims_dict: