mirror of
				https://github.com/ml-explore/mlx-examples.git
				synced 2025-11-01 03:28:08 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			54 lines
		
	
	
		
			1.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			54 lines
		
	
	
		
			1.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright © 2023 Apple Inc.
 | |
| 
 | |
| import argparse
 | |
| from itertools import starmap
 | |
| 
 | |
| import numpy as np
 | |
| import torch
 | |
| 
 | |
| 
 | |
def map_torch_to_mlx(key, value):
    """Translate one PyTorch checkpoint entry to its MLX name and a NumPy array.

    The branches are tested in order and only the first match applies, so a
    key containing "norm" is never touched by the later projection/FFN rules.

    Args:
        key: Parameter name taken from the PyTorch state dict.
        value: The ``torch.Tensor`` stored under ``key``.

    Returns:
        A ``(new_key, array)`` tuple where ``array`` is a ``numpy.ndarray``,
        or ``(None, None)`` when the key contains "rope" (and matched no
        earlier rule), signalling that the entry should be skipped.
    """
    if "tok_embedding" in key:
        key = "embedding.weight"

    elif "norm" in key:
        # Keys that contain "norm" but neither of these substrings pass
        # through with their name unchanged.
        key = key.replace("attention_norm", "norm1").replace("ffn_norm", "norm2")

    elif "wq" in key or "wk" in key or "wv" in key or "wo" in key:
        key = key.replace("wq", "query_proj")
        key = key.replace("wk", "key_proj")
        key = key.replace("wv", "value_proj")
        key = key.replace("wo", "out_proj")

    elif "w1" in key or "w2" in key or "w3" in key:
        # The FFN is a separate submodule in PyTorch
        # Note the ordering: w3 becomes linear2 and w2 becomes linear3.
        key = key.replace("feed_forward.w1", "linear1")
        key = key.replace("feed_forward.w3", "linear2")
        key = key.replace("feed_forward.w2", "linear3")

    elif "output" in key:
        key = key.replace("output", "out_proj")

    elif "rope" in key:
        return None, None

    # Tensor.numpy() refuses tensors that require grad; detaching first lets
    # checkpoints that stored trainable parameters convert as well.
    value = value.detach()

    # NumPy has no bfloat16 dtype, so such tensors are upcast to float32.
    if value.dtype == torch.bfloat16:
        value = value.to(torch.float32)

    return key, value.numpy()
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    # Command-line entry point: read a PyTorch checkpoint, convert every
    # entry with map_torch_to_mlx, and write the result with np.savez.
    cli = argparse.ArgumentParser(description="Convert Llama weights to MLX")
    cli.add_argument("torch_weights")
    cli.add_argument("output_file")
    opts = cli.parse_args()

    # Map all tensors onto the CPU while loading.
    # NOTE(review): torch.load unpickles the file, so this should only be run
    # on checkpoints from a trusted source -- confirm whether a restricted
    # loading mode is available for the torch versions this script targets.
    checkpoint = torch.load(opts.torch_weights, map_location=torch.device("cpu"))

    # Entries for which the mapper returns a None key are left out.
    converted = {}
    for name, tensor in checkpoint.items():
        mlx_name, array = map_torch_to_mlx(name, tensor)
        if mlx_name is not None:
            converted[mlx_name] = array

    np.savez(opts.output_file, **converted)
 | 
