import argparse
from typing import List

import model  # the local MLX BERT implementation (provides model.run)
import numpy as np
from transformers import AutoModel, AutoTokenizer


def run_torch(bert_model: str, batch: List[str]):
    # Run the batch through the Hugging Face PyTorch model; its outputs are
    # the reference against which the MLX implementation is checked.
    tokenizer = AutoTokenizer.from_pretrained(bert_model)
    torch_model = AutoModel.from_pretrained(bert_model)
    torch_tokens = tokenizer(batch, return_tensors="pt", padding=True)
    torch_forward = torch_model(**torch_tokens)
    torch_output = torch_forward.last_hidden_state.detach().numpy()
    torch_pooled = torch_forward.pooler_output.detach().numpy()
    return torch_output, torch_pooled
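

# Example invocation (script filename assumed, not shown in this file):
#   python test.py --bert-model bert-base-uncased --text "hello" "world"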
					
						
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run a BERT-like model for a batch of text."
    )
    parser.add_argument(
        "--bert-model",
        type=str,
        default="bert-base-uncased",
        help="The model identifier for a BERT-like model from Hugging Face Transformers.",
    )
    parser.add_argument(
        "--mlx-model",
        type=str,
        default="weights/bert-base-uncased.npz",
        help="The path of the stored MLX BERT weights (npz file).",
    )
    parser.add_argument(
        "--text",
        nargs="+",
        default=["This is an example of BERT working in MLX."],
        help="A batch of texts to process. Pass multiple texts as separate quoted arguments.",
    )
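    # Example: --text "first sentence" "second sentence" processes a batch of two.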
					
						
    args = parser.parse_args()

    # Reference outputs from the Hugging Face PyTorch implementation.
    torch_output, torch_pooled = run_torch(args.bert_model, args.text)

    # Outputs from the MLX implementation using the weights at --mlx-model.
    mlx_output, mlx_pooled = model.run(args.bert_model, args.mlx_model, args.text)
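
    # Loose tolerances absorb small numerical differences between frameworks.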
					
						
    if torch_pooled is not None and mlx_pooled is not None:
        assert np.allclose(
            torch_output, mlx_output, rtol=1e-4, atol=1e-5
        ), "Model output is different"
        assert np.allclose(
            torch_pooled, mlx_pooled, rtol=1e-4, atol=1e-5
        ), "Model pooled output is different"
        print("Tests pass :)")
    else:
        print("Comparison skipped: one or both pooled outputs were None.")