mirror of
				https://github.com/ml-explore/mlx-examples.git
				synced 2025-10-31 19:18:09 +08:00 
			
		
		
		
	 4680ef4413
			
		
	
	4680ef4413
	
	
	
		
			
			* Update convert.py * Update model.py * Update test.py * Update model.py * Update convert.py * Add files via upload * Update convert.py * format * nit * nit --------- Co-authored-by: Awni Hannun <awni@apple.com>
		
			
				
	
	
		
			58 lines
		
	
	
		
			1.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			58 lines
		
	
	
		
			1.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import argparse
 | |
| from typing import List
 | |
| 
 | |
| import model
 | |
| import numpy as np
 | |
| from transformers import AutoModel, AutoTokenizer
 | |
| 
 | |
| 
 | |
| def run_torch(bert_model: str, batch: List[str]):
 | |
|     tokenizer = AutoTokenizer.from_pretrained(bert_model)
 | |
|     torch_model = AutoModel.from_pretrained(bert_model)
 | |
|     torch_tokens = tokenizer(batch, return_tensors="pt", padding=True)
 | |
|     torch_forward = torch_model(**torch_tokens)
 | |
|     torch_output = torch_forward.last_hidden_state.detach().numpy()
 | |
|     torch_pooled = torch_forward.pooler_output.detach().numpy()
 | |
|     return torch_output, torch_pooled
 | |
| 
 | |
| 
 | |
| if __name__ == "__main__":
 | |
|     parser = argparse.ArgumentParser(
 | |
|         description="Run a BERT-like model for a batch of text."
 | |
|     )
 | |
|     parser.add_argument(
 | |
|         "--bert-model",
 | |
|         type=str,
 | |
|         default="bert-base-uncased",
 | |
|         help="The model identifier for a BERT-like model from Hugging Face Transformers.",
 | |
|     )
 | |
|     parser.add_argument(
 | |
|         "--mlx-model",
 | |
|         type=str,
 | |
|         default="weights/bert-base-uncased.npz",
 | |
|         help="The path of the stored MLX BERT weights (npz file).",
 | |
|     )
 | |
|     parser.add_argument(
 | |
|         "--text",
 | |
|         nargs="+",
 | |
|         default=["This is an example of BERT working in MLX."],
 | |
|         help="A batch of texts to process. Multiple texts should be separated by spaces.",
 | |
|     )
 | |
| 
 | |
|     args = parser.parse_args()
 | |
| 
 | |
|     torch_output, torch_pooled = run_torch(args.bert_model, args.text)
 | |
| 
 | |
|     mlx_output, mlx_pooled = model.run(args.bert_model, args.mlx_model, args.text)
 | |
| 
 | |
|     if torch_pooled is not None and mlx_pooled is not None:
 | |
|         assert np.allclose(
 | |
|             torch_output, mlx_output, rtol=1e-4, atol=1e-5
 | |
|         ), "Model output is different"
 | |
|         assert np.allclose(
 | |
|             torch_pooled, mlx_pooled, rtol=1e-4, atol=1e-5
 | |
|         ), "Model pooled output is different"
 | |
|         print("Tests pass :)")
 | |
|     else:
 | |
|         print("Pooled outputs were not compared due to one or both being None.")
 |