faster group norm (#1304)

Awni Hannun 2024-08-01 12:49:23 -07:00 committed by GitHub
parent 43ffdab172
commit 6c8dd307eb


@@ -199,18 +199,17 @@ class GroupNorm(Module):
     def _pytorch_compatible_group_norm(self, x):
         num_groups = self.num_groups
         batch, *rest, dims = x.shape
+        group_size = dims // num_groups
 
         # Split into groups
-        x = x.reshape(batch, -1, num_groups, dims // num_groups)
-        x = x.transpose(0, 1, 3, 2).reshape(batch, -1, num_groups)
+        x = x.reshape(batch, -1, num_groups, group_size)
+        x = x.transpose(0, 2, 1, 3).reshape(batch, num_groups, -1)
 
         # Normalize
-        means = mx.mean(x, axis=1, keepdims=True)
-        var = mx.var(x, axis=1, keepdims=True)
-        x = (x - means) * mx.rsqrt(var + self.eps)
-        x = x.reshape(batch, -1, dims // num_groups, num_groups)
-        x = x.transpose(0, 1, 3, 2).reshape(batch, *rest, dims)
+        x = mx.fast.layer_norm(x, eps=self.eps, weight=None, bias=None)
+        x = x.reshape(batch, num_groups, -1, group_size)
+        x = x.transpose(0, 2, 1, 3).reshape(batch, *rest, dims)
 
         return x
 
     def _group_norm(self, x):
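
For reference, a minimal standalone sketch of why the new path gives the same result; this is not part of the commit, and the shapes, seed, and tolerance below are illustrative assumptions. Once each group's elements are moved onto the last axis, mx.fast.layer_norm with no weight or bias computes the same per-group mean/variance normalization the old code did with explicit ops.

import mlx.core as mx

# Illustrative shapes and seed (assumptions, not from the commit).
batch, length, dims, num_groups, eps = 2, 5, 8, 4, 1e-5
group_size = dims // num_groups

mx.random.seed(0)
x = mx.random.normal((batch, length, dims))

# Old path: split into groups, then normalize with explicit mean/variance.
y = x.reshape(batch, -1, num_groups, group_size)
y = y.transpose(0, 1, 3, 2).reshape(batch, -1, num_groups)
means = mx.mean(y, axis=1, keepdims=True)
var = mx.var(y, axis=1, keepdims=True)
y = (y - means) * mx.rsqrt(var + eps)
y = y.reshape(batch, -1, group_size, num_groups)
y = y.transpose(0, 1, 3, 2).reshape(batch, length, dims)

# New path: put each group's elements on the last axis, then one fused call.
z = x.reshape(batch, -1, num_groups, group_size)
z = z.transpose(0, 2, 1, 3).reshape(batch, num_groups, -1)
z = mx.fast.layer_norm(z, eps=eps, weight=None, bias=None)
z = z.reshape(batch, num_groups, -1, group_size)
z = z.transpose(0, 2, 1, 3).reshape(batch, length, dims)

print(mx.allclose(y, z, atol=1e-5))  # expected: True

The speedup comes from collapsing the separate mean, var, subtract, and rsqrt operations into a single fused mx.fast.layer_norm call; the revised transposes only rearrange axes so that each group's elements form the last axis that layer_norm normalizes over.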