From 6c8dd307eb8aaecd782596e43b49ff7cdb126abf Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 1 Aug 2024 12:49:23 -0700 Subject: [PATCH] faster group norm (#1304) --- python/mlx/nn/layers/normalization.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/python/mlx/nn/layers/normalization.py b/python/mlx/nn/layers/normalization.py index 588c2ed2b..bdddb6ccf 100644 --- a/python/mlx/nn/layers/normalization.py +++ b/python/mlx/nn/layers/normalization.py @@ -199,18 +199,17 @@ class GroupNorm(Module): def _pytorch_compatible_group_norm(self, x): num_groups = self.num_groups batch, *rest, dims = x.shape + group_size = dims // num_groups # Split into groups - x = x.reshape(batch, -1, num_groups, dims // num_groups) - x = x.transpose(0, 1, 3, 2).reshape(batch, -1, num_groups) + x = x.reshape(batch, -1, num_groups, group_size) + x = x.transpose(0, 2, 1, 3).reshape(batch, num_groups, -1) # Normalize - means = mx.mean(x, axis=1, keepdims=True) - var = mx.var(x, axis=1, keepdims=True) - x = (x - means) * mx.rsqrt(var + self.eps) - x = x.reshape(batch, -1, dims // num_groups, num_groups) - x = x.transpose(0, 1, 3, 2).reshape(batch, *rest, dims) + x = mx.fast.layer_norm(x, eps=self.eps, weight=None, bias=None) + x = x.reshape(batch, num_groups, -1, group_size) + x = x.transpose(0, 2, 1, 3).reshape(batch, *rest, dims) return x def _group_norm(self, x):