mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-22 13:07:51 +08:00
Update python/mlx/nn/layers/normalization.py
Update BatchNorm to support NLC and NHWC input formats In our convolution operations, we follow the convention that the channels are the last dimension. This commit updates the BatchNorm implementation to support inputs where the channels are the last dimension (NLC or NHWC). This involves changing the dimensions we normalize over and the dimensions we expand our parameters over. Co-authored-by: Robert McCraith <mccraithrobert@gmail.com>
This commit is contained in:
parent
cf5a5a4a01
commit
28009c9cdb
@ -254,8 +254,8 @@ class BatchNorm(Module):
|
|||||||
num_dims = len(x.shape)
|
num_dims = len(x.shape)
|
||||||
dims_dict = {
|
dims_dict = {
|
||||||
2: ((1, self.num_features), (0,)),
|
2: ((1, self.num_features), (0,)),
|
||||||
3: ((1, self.num_features, 1), (0, 2)),
|
3: ((1, 1, self.num_features), (0, 1)),
|
||||||
4: ((1, self.num_features, 1, 1), (0, 2, 3)),
|
4: ((1, 1, 1, self.num_features), (0, 1, 2)),
|
||||||
}
|
}
|
||||||
|
|
||||||
if num_dims not in dims_dict:
|
if num_dims not in dims_dict:
|
||||||
|
Loading…
Reference in New Issue
Block a user