mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 02:28:13 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			64 lines
		
	
	
		
			1.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			64 lines
		
	
	
		
			1.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
# Copyright © 2023-2024 Apple Inc.
 | 
						|
 | 
						|
import mlx.core as mx
 | 
						|
import mlx.nn as nn
 | 
						|
from time_utils import time_fn
 | 
						|
 | 
						|
 | 
						|
def rms_norm(x, w, eps):
    """Reference RMS normalization over the last axis, computed in float32.

    The input is upcast to ``mx.float32`` for the reduction, scaled by the
    reciprocal root of its mean square, and cast back to its original dtype
    before the optional weight is applied.

    Args:
        x: Input array; normalized along its last axis.
        w: Optional scale multiplied into the normalized result. ``None``
            skips the scaling step.
        eps: Value added to the mean square before taking the reciprocal
            square root.

    Returns:
        The normalized (and, if ``w`` is given, scaled) array.
    """
    input_dtype = x.dtype
    x32 = x.astype(mx.float32)

    # Reciprocal RMS, kept as a broadcastable trailing dimension.
    mean_sq = x32.square().mean(-1, keepdims=True)
    inv_rms = mx.rsqrt(mean_sq + eps)

    # Cast back *before* applying the weight, matching the original order.
    normed = (x32 * inv_rms).astype(input_dtype)
    if w is None:
        return normed
    return normed * w
 | 
						|
 | 
						|
 | 
						|
def time_rms_norm():
    """Time gradients of the reference ``rms_norm`` against ``mx.fast.rms_norm``.

    Two scenarios are benchmarked, each on float16 data of shape
    ``(8, 1024, 4096)``:

    1. With a weight vector: gradients w.r.t. both the input and the weight.
    2. Without a weight (``None``): gradient w.r.t. the input only.

    In each scenario the gradient function is applied 32 times in a chain
    (each output fed back in as the next input), first uncompiled and then
    wrapped in ``mx.compile``, for both the reference and fast versions.
    """
    eps = 1e-5
    data_shape = (8, 1024, 4096)
    weight_shape = (4096,)
    steps = 32

    def passthrough(fn):
        # Identity wrapper so plain and compiled runs share one loop below.
        return fn

    # ------------------------------------------------------------------
    # Scenario 1: weighted RMS norm, gradients for (x, w).
    # ------------------------------------------------------------------
    def weighted_ref_loss(x, w, y):
        return (rms_norm(x, w, eps) * y).sum()

    def weighted_fast_loss(x, w, y):
        return (mx.fast.rms_norm(x, w, eps) * y).sum()

    weighted_ref_grad = mx.grad(weighted_ref_loss, argnums=(0, 1))
    weighted_fast_grad = mx.grad(weighted_fast_loss, argnums=(0, 1))

    x = mx.random.uniform(shape=data_shape).astype(mx.float16)
    w = mx.random.uniform(shape=weight_shape).astype(mx.float16)
    y = mx.random.uniform(shape=data_shape).astype(mx.float16)
    # Materialize the inputs up front so their creation is not timed.
    mx.eval(x, w, y)

    # NOTE(review): time_fn comes from time_utils and presumably labels its
    # output with the callable's name, so the name rms_norm_loop is kept for
    # both loop helpers -- confirm before renaming.
    def rms_norm_loop(g, x, w):
        grad_x, grad_w = x, w
        for _ in range(steps):
            grad_x, grad_w = g(grad_x, grad_w, y)
        return grad_x, grad_w

    # Order: reference, fast, compiled reference, compiled fast.
    for wrap in (passthrough, mx.compile):
        for grad_fn in (weighted_ref_grad, weighted_fast_grad):
            time_fn(rms_norm_loop, wrap(grad_fn), x, w)

    # ------------------------------------------------------------------
    # Scenario 2: RMS norm without a weight, gradient for x only.
    # ------------------------------------------------------------------
    def plain_ref_loss(x, y):
        return (rms_norm(x, None, eps) * y).sum()

    def plain_fast_loss(x, y):
        return (mx.fast.rms_norm(x, None, eps) * y).sum()

    plain_ref_grad = mx.grad(plain_ref_loss, argnums=(0,))
    plain_fast_grad = mx.grad(plain_fast_loss, argnums=(0,))

    x = mx.random.uniform(shape=data_shape).astype(mx.float16)
    # w is regenerated and evaluated here but not consumed by this scenario.
    w = mx.random.uniform(shape=weight_shape).astype(mx.float16)
    y = mx.random.uniform(shape=data_shape).astype(mx.float16)
    mx.eval(x, w, y)

    def rms_norm_loop(g, x):
        grad_x = x
        for _ in range(steps):
            grad_x = g(grad_x, y)
        return grad_x

    for wrap in (passthrough, mx.compile):
        for grad_fn in (plain_ref_grad, plain_fast_grad):
            time_fn(rms_norm_loop, wrap(grad_fn), x)
 | 
						|
 | 
						|
 | 
						|
# Run the benchmark only when this file is executed as a script.
if __name__ == "__main__":
    time_rms_norm()
 |