mirror of
				https://github.com/ml-explore/mlx-examples.git
				synced 2025-10-31 19:18:09 +08:00 
			
		
		
		
	 4680ef4413
			
		
	
	4680ef4413
	
	
	
		
			
			* Update convert.py * Update model.py * Update test.py * Update model.py * Update convert.py * Add files via upload * Update convert.py * format * nit * nit --------- Co-authored-by: Awni Hannun <awni@apple.com>
		
			
				
	
	
		
			48 lines
		
	
	
		
			1.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			48 lines
		
	
	
		
			1.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import argparse
 | |
| 
 | |
| import numpy
 | |
| from transformers import AutoModel
 | |
| 
 | |
| 
 | |
def replace_key(key: str) -> str:
    """Rename a source checkpoint parameter key to the target naming scheme.

    The substitutions are applied one after another, in the order listed.
    The order is significant: the attention-specific patterns must be
    rewritten before the more general ``.output.dense.``,
    ``.output.LayerNorm.`` and ``.LayerNorm.`` patterns, which would
    otherwise match the same substrings.

    Args:
        key: Parameter name as it appears in the source state dict.

    Returns:
        The key with every applicable substitution applied.
    """
    rename_rules = (
        (".layer.", ".layers."),
        (".self.key.", ".key_proj."),
        (".self.query.", ".query_proj."),
        (".self.value.", ".value_proj."),
        (".attention.output.dense.", ".attention.out_proj."),
        (".attention.output.LayerNorm.", ".ln1."),
        (".output.LayerNorm.", ".ln2."),
        (".intermediate.dense.", ".linear1."),
        (".output.dense.", ".linear2."),
        (".LayerNorm.", ".norm."),
        ("pooler.dense.", "pooler."),
    )
    renamed = key
    for old_fragment, new_fragment in rename_rules:
        renamed = renamed.replace(old_fragment, new_fragment)
    return renamed
 | |
| 
 | |
| 
 | |
def convert(bert_model: str, mlx_model: str) -> None:
    """Load a pretrained model and save its weights as a NumPy ``.npz`` archive.

    Every entry of the model's state dict is renamed with ``replace_key``,
    converted to a NumPy array and written with ``numpy.savez``.

    Args:
        bert_model: Model identifier passed to ``AutoModel.from_pretrained``.
        mlx_model: Output path passed to ``numpy.savez``. Missing parent
            directories are created.
    """
    from pathlib import Path

    model = AutoModel.from_pretrained(bert_model)
    # save the tensors
    tensors = {
        replace_key(key): tensor.numpy() for key, tensor in model.state_dict().items()
    }
    # numpy.savez does not create missing directories, so make sure the
    # destination folder exists instead of failing after the model is loaded.
    Path(mlx_model).parent.mkdir(parents=True, exist_ok=True)
    numpy.savez(mlx_model, **tensors)
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    # Command-line entry point: read the source model name and the output
    # path, then run the conversion.
    cli = argparse.ArgumentParser(description="Convert BERT weights to MLX.")
    cli.add_argument(
        "--bert-model",
        default="bert-base-uncased",
        type=str,
        help="The huggingface name of the BERT model to save. Any BERT-like model can be specified.",
    )
    cli.add_argument(
        "--mlx-model",
        default="weights/bert-base-uncased.npz",
        type=str,
        help="The output path for the MLX BERT weights.",
    )
    options = cli.parse_args()

    convert(bert_model=options.bert_model, mlx_model=options.mlx_model)
 |