| 
									
										
										
										
											2023-12-08 05:14:11 -05:00
										 |  |  | import argparse | 
					
						
							| 
									
										
										
										
											2023-12-20 10:22:25 -08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-08 05:14:11 -05:00
										 |  |  | import numpy | 
					
						
							| 
									
										
										
										
											2024-03-20 01:21:33 +01:00
										 |  |  | from transformers import AutoModel | 
					
						
							| 
									
										
										
										
											2023-12-08 05:14:11 -05:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def replace_key(key: str) -> str: | 
					
						
							|  |  |  |     key = key.replace(".layer.", ".layers.") | 
					
						
							|  |  |  |     key = key.replace(".self.key.", ".key_proj.") | 
					
						
							|  |  |  |     key = key.replace(".self.query.", ".query_proj.") | 
					
						
							|  |  |  |     key = key.replace(".self.value.", ".value_proj.") | 
					
						
							|  |  |  |     key = key.replace(".attention.output.dense.", ".attention.out_proj.") | 
					
						
							|  |  |  |     key = key.replace(".attention.output.LayerNorm.", ".ln1.") | 
					
						
							|  |  |  |     key = key.replace(".output.LayerNorm.", ".ln2.") | 
					
						
							|  |  |  |     key = key.replace(".intermediate.dense.", ".linear1.") | 
					
						
							|  |  |  |     key = key.replace(".output.dense.", ".linear2.") | 
					
						
							|  |  |  |     key = key.replace(".LayerNorm.", ".norm.") | 
					
						
							|  |  |  |     key = key.replace("pooler.dense.", "pooler.") | 
					
						
							|  |  |  |     return key | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def convert(bert_model: str, mlx_model: str) -> None: | 
					
						
							| 
									
										
										
										
											2024-03-20 01:21:33 +01:00
										 |  |  |     model = AutoModel.from_pretrained(bert_model) | 
					
						
							| 
									
										
										
										
											2023-12-08 05:14:11 -05:00
										 |  |  |     # save the tensors | 
					
						
							|  |  |  |     tensors = { | 
					
						
							|  |  |  |         replace_key(key): tensor.numpy() for key, tensor in model.state_dict().items() | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     numpy.savez(mlx_model, **tensors) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | if __name__ == "__main__": | 
					
						
							|  |  |  |     parser = argparse.ArgumentParser(description="Convert BERT weights to MLX.") | 
					
						
							|  |  |  |     parser.add_argument( | 
					
						
							|  |  |  |         "--bert-model", | 
					
						
							| 
									
										
										
										
											2024-03-20 01:21:33 +01:00
										 |  |  |         type=str, | 
					
						
							| 
									
										
										
										
											2023-12-08 05:14:11 -05:00
										 |  |  |         default="bert-base-uncased", | 
					
						
							| 
									
										
										
										
											2024-03-20 01:21:33 +01:00
										 |  |  |         help="The huggingface name of the BERT model to save. Any BERT-like model can be specified.", | 
					
						
							| 
									
										
										
										
											2023-12-08 05:14:11 -05:00
										 |  |  |     ) | 
					
						
							|  |  |  |     parser.add_argument( | 
					
						
							|  |  |  |         "--mlx-model", | 
					
						
							|  |  |  |         type=str, | 
					
						
							|  |  |  |         default="weights/bert-base-uncased.npz", | 
					
						
							|  |  |  |         help="The output path for the MLX BERT weights.", | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     args = parser.parse_args() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-09 14:15:25 -08:00
										 |  |  |     convert(args.bert_model, args.mlx_model) |