mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Fix copy donation and add partial rope (#881)
This commit is contained in:
		 Angelos Katharopoulos
					Angelos Katharopoulos
				
			
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			 GitHub
						GitHub
					
				
			
						parent
						
							8e5a5a1ccd
						
					
				
				
					commit
					6ee1112f30
				
			| @@ -6,21 +6,21 @@ from time_utils import time_fn | ||||
|  | ||||
|  | ||||
| def time_rope(): | ||||
|     rope = nn.RoPE(4096) | ||||
|     rope = nn.RoPE(64) | ||||
|  | ||||
|     # vec | ||||
|     x = mx.random.uniform(shape=(1, 4096)).astype(mx.float16) | ||||
|     x = mx.random.uniform(shape=(1, 32, 1, 128)).astype(mx.float16) | ||||
|     mx.eval(x) | ||||
|  | ||||
|     def rope_vec(x): | ||||
|         for _ in range(32): | ||||
|             x = rope(x) | ||||
|             x = rope(x, offset=100) | ||||
|         return x | ||||
|  | ||||
|     time_fn(rope_vec, x) | ||||
|  | ||||
|     # matrix | ||||
|     x = mx.random.uniform(shape=(1024, 4096)).astype(mx.float16) | ||||
|     x = mx.random.uniform(shape=(1, 32, 1024, 128)).astype(mx.float16) | ||||
|     mx.eval(x) | ||||
|  | ||||
|     def rope_mat(x): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user