Update python/mlx/nn/layers/normalization.py

Update BatchNorm to support NLC and NHWC input formats

In our convolution operations, we follow the convention that the channels are the last dimension. This commit updates the BatchNorm implementation to support inputs where the channels are the last dimension (NLC or NHWC). This involves changing the dimensions we normalize over and the dimensions we expand our parameters over.

Co-authored-by: Robert McCraith <mccraithrobert@gmail.com>
This commit is contained in:
__mo_san__ 2023-12-24 22:22:26 +01:00 committed by m0saan
parent cf5a5a4a01
commit 28009c9cdb

View File

@ -254,8 +254,8 @@ class BatchNorm(Module):
num_dims = len(x.shape) num_dims = len(x.shape)
dims_dict = { dims_dict = {
2: ((1, self.num_features), (0,)), 2: ((1, self.num_features), (0,)),
3: ((1, self.num_features, 1), (0, 2)), 3: ((1, 1, self.num_features), (0, 1)),
4: ((1, self.num_features, 1, 1), (0, 2, 3)), 4: ((1, 1, 1, self.num_features), (0, 1, 2)),
} }
if num_dims not in dims_dict: if num_dims not in dims_dict: