mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
faster group norm (#1304)
This commit is contained in:
parent
43ffdab172
commit
6c8dd307eb
@ -199,18 +199,17 @@ class GroupNorm(Module):
|
||||
def _pytorch_compatible_group_norm(self, x):
|
||||
num_groups = self.num_groups
|
||||
batch, *rest, dims = x.shape
|
||||
group_size = dims // num_groups
|
||||
|
||||
# Split into groups
|
||||
x = x.reshape(batch, -1, num_groups, dims // num_groups)
|
||||
x = x.transpose(0, 1, 3, 2).reshape(batch, -1, num_groups)
|
||||
x = x.reshape(batch, -1, num_groups, group_size)
|
||||
x = x.transpose(0, 2, 1, 3).reshape(batch, num_groups, -1)
|
||||
|
||||
# Normalize
|
||||
means = mx.mean(x, axis=1, keepdims=True)
|
||||
var = mx.var(x, axis=1, keepdims=True)
|
||||
x = (x - means) * mx.rsqrt(var + self.eps)
|
||||
x = x.reshape(batch, -1, dims // num_groups, num_groups)
|
||||
x = x.transpose(0, 1, 3, 2).reshape(batch, *rest, dims)
|
||||
x = mx.fast.layer_norm(x, eps=self.eps, weight=None, bias=None)
|
||||
|
||||
x = x.reshape(batch, num_groups, -1, group_size)
|
||||
x = x.transpose(0, 2, 1, 3).reshape(batch, *rest, dims)
|
||||
return x
|
||||
|
||||
def _group_norm(self, x):
|
||||
|
Loading…
Reference in New Issue
Block a user