def _pytorch_compatible_group_norm(self, x):
    """Normalize ``x`` group-wise without applying any affine parameters.

    The input shape is unpacked as ``(batch, *rest, dims)``, i.e. the last
    axis holds the channels. The channels are split into ``self.num_groups``
    contiguous groups of ``dims // self.num_groups`` channels each, and every
    group is normalized over all of its channels and all ``rest`` positions
    of one batch element.

    Args:
        x: Input array whose last axis has ``dims`` channels.

    Returns:
        The normalized array, reshaped back to ``(batch, *rest, dims)``.

    Raises:
        ValueError: If ``dims`` is not a multiple of ``self.num_groups``.
    """
    num_groups = self.num_groups
    batch, *rest, dims = x.shape

    # A floored group size would let the ``-1`` in the reshapes below absorb
    # the mismatch, silently assigning channels to the wrong groups.
    group_size, remainder = divmod(dims, num_groups)
    if remainder != 0:
        raise ValueError(
            f"The number of channels ({dims}) must be divisible by "
            f"the number of groups ({num_groups})."
        )

    # Split into groups: (batch, positions, groups, group_size) is reordered
    # so that each row of (batch, groups, -1) holds one complete group.
    x = x.reshape(batch, -1, num_groups, group_size)
    x = x.transpose(0, 2, 1, 3).reshape(batch, num_groups, -1)

    # Normalize each group row; no scale or shift is applied here
    # (weight and bias are explicitly None).
    x = mx.fast.layer_norm(x, eps=self.eps, weight=None, bias=None)

    # Undo the grouping: the same axis swap is its own inverse.
    x = x.reshape(batch, num_groups, -1, group_size)
    x = x.transpose(0, 2, 1, 3).reshape(batch, *rest, dims)

    return x