mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	faster group norm (#1304)
This commit is contained in:
		| @@ -199,18 +199,17 @@ class GroupNorm(Module): | ||||
|     def _pytorch_compatible_group_norm(self, x): | ||||
|         num_groups = self.num_groups | ||||
|         batch, *rest, dims = x.shape | ||||
|         group_size = dims // num_groups | ||||
|  | ||||
|         # Split into groups | ||||
|         x = x.reshape(batch, -1, num_groups, dims // num_groups) | ||||
|         x = x.transpose(0, 1, 3, 2).reshape(batch, -1, num_groups) | ||||
|         x = x.reshape(batch, -1, num_groups, group_size) | ||||
|         x = x.transpose(0, 2, 1, 3).reshape(batch, num_groups, -1) | ||||
|  | ||||
|         # Normalize | ||||
|         means = mx.mean(x, axis=1, keepdims=True) | ||||
|         var = mx.var(x, axis=1, keepdims=True) | ||||
|         x = (x - means) * mx.rsqrt(var + self.eps) | ||||
|         x = x.reshape(batch, -1, dims // num_groups, num_groups) | ||||
|         x = x.transpose(0, 1, 3, 2).reshape(batch, *rest, dims) | ||||
|         x = mx.fast.layer_norm(x, eps=self.eps, weight=None, bias=None) | ||||
|  | ||||
|         x = x.reshape(batch, num_groups, -1, group_size) | ||||
|         x = x.transpose(0, 2, 1, 3).reshape(batch, *rest, dims) | ||||
|         return x | ||||
|  | ||||
|     def _group_norm(self, x): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun