mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 02:28:13 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			36 lines
		
	
	
		
			637 B
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			36 lines
		
	
	
		
			637 B
		
	
	
	
		
			Python
		
	
	
	
	
	
# Copyright © 2023-2024 Apple Inc.
 | 
						|
 | 
						|
import mlx.core as mx
 | 
						|
import mlx.nn as nn
 | 
						|
from time_utils import time_fn
 | 
						|
 | 
						|
 | 
						|
def time_rope():
    """Time ``nn.RoPE(64)`` on a 1-row and a 1024-row float16 input.

    Each timed function chains the same RoPE module 32 times, feeding every
    output back in as the next input. The "vec" case passes ``offset=100``;
    the "matrix" case uses the module's default offset. Inputs are evaluated
    with ``mx.eval`` before timing so that random-number generation is not
    included in the measurement. Timing and reporting are delegated to
    ``time_fn``.
    """
    rotary = nn.RoPE(64)
    num_applications = 32

    def rope_vec(x):
        # Chain the rotary embedding with a fixed position offset of 100.
        for _step in range(num_applications):
            x = rotary(x, offset=100)
        return x

    def rope_mat(x):
        # Chain the rotary embedding with the default offset.
        for _step in range(num_applications):
            x = rotary(x)
        return x

    # "vec" case: a single row along axis -2 (assumed to be the sequence
    # axis, as nn.RoPE treats it -- TODO confirm layout is (B, H, L, D)).
    vec_input = mx.random.uniform(shape=(1, 32, 1, 128)).astype(mx.float16)
    mx.eval(vec_input)
    time_fn(rope_vec, vec_input)

    # "matrix" case: 1024 rows along axis -2, same 128-wide last axis.
    mat_input = mx.random.uniform(shape=(1, 32, 1024, 128)).astype(mx.float16)
    mx.eval(mat_input)
    time_fn(rope_mat, mat_input)
 | 
						|
 | 
						|
 | 
						|
if __name__ == "__main__":
    # Only run the benchmark when executed as a script, not on import.
    time_rope()
 |