diff --git a/python/mlx/nn/layers/normalization.py b/python/mlx/nn/layers/normalization.py
index 5807d3d7a..769bab49b 100644
--- a/python/mlx/nn/layers/normalization.py
+++ b/python/mlx/nn/layers/normalization.py
@@ -99,7 +99,7 @@ class GroupNorm(Module):
     where :math:`\gamma` and :math:`\beta` are learned per feature dimension
     parameters initialized at 1 and 0 respectively. However, the mean and
     variance are computed over the spatial dimensions and each group of
-    features. In particular, the input is split into num_groups accross the
+    features. In particular, the input is split into num_groups across the
     feature dimension.
 
     The feature dimension is assumed to be the last dimension and the dimensions