mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	
		
			
	
	
		
			42 lines
		
	
	
		
			1.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			42 lines
		
	
	
		
			1.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| 
								 | 
							
								# Copyright © 2023-2024 Apple Inc.
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								import mlx.core as mx
							 | 
						||
| 
								 | 
							
								import mlx.nn as nn
							 | 
						||
| 
								 | 
							
								from time_utils import time_fn
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def layer_norm(x, w, b, eps):
							 | 
						||
| 
								 | 
							
								    ot = x.dtype
							 | 
						||
| 
								 | 
							
								    x = x.astype(mx.float32)
							 | 
						||
| 
								 | 
							
								    mu = mx.mean(x, -1, keepdims=True)
							 | 
						||
| 
								 | 
							
								    v = mx.var(x, -1, keepdims=True)
							 | 
						||
| 
								 | 
							
								    return (x - mu) * mx.rsqrt(v + eps) * w + b
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def time_layer_norm():
							 | 
						||
| 
								 | 
							
								    f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum()
							 | 
						||
| 
								 | 
							
								    f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum()
							 | 
						||
| 
								 | 
							
								    g1 = mx.grad(f1, argnums=(0, 1, 2))
							 | 
						||
| 
								 | 
							
								    g2 = mx.grad(f2, argnums=(0, 1, 2))
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
							 | 
						||
| 
								 | 
							
								    w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
							 | 
						||
| 
								 | 
							
								    b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
							 | 
						||
| 
								 | 
							
								    y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
							 | 
						||
| 
								 | 
							
								    mx.eval(x, w, b, y)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def layer_norm_loop(g, x, w, b):
							 | 
						||
| 
								 | 
							
								        gx, gw, gb = x, w, b
							 | 
						||
| 
								 | 
							
								        for _ in range(32):
							 | 
						||
| 
								 | 
							
								            gx, gw, gb = g(gx, gw, gb, y)
							 | 
						||
| 
								 | 
							
								        return gx, gw, gb
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    time_fn(layer_norm_loop, g1, x, w, b)
							 | 
						||
| 
								 | 
							
								    time_fn(layer_norm_loop, g2, x, w, b)
							 | 
						||
| 
								 | 
							
								    time_fn(layer_norm_loop, mx.compile(g1), x, w, b)
							 | 
						||
| 
								 | 
							
								    time_fn(layer_norm_loop, mx.compile(g2), x, w, b)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								if __name__ == "__main__":
							 | 
						||
| 
								 | 
							
								    time_layer_norm()
							 |