# BERT

An implementation of BERT [(Devlin, et al., 2019)](https://aclanthology.org/N19-1423/) in MLX.

## Setup

Install the requirements:

```
pip install -r requirements.txt
```

Then convert the weights with:

```
python convert.py \
    --bert-model bert-base-uncased \
    --mlx-model weights/bert-base-uncased.npz
```
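
If you want a quick sanity check that the conversion worked, one option is to
load the `.npz` file back with MLX and inspect a few parameter names and
shapes (a minimal sketch, assuming the output path above):

```python
import mlx.core as mx

# Loading an .npz file returns a dict mapping parameter names to arrays.
weights = mx.load("weights/bert-base-uncased.npz")

# Print a handful of parameter names and shapes.
for name in sorted(weights)[:5]:
    print(name, weights[name].shape)
```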

## Usage

To use the `Bert` model in your own code, you can load it with:

```python
import mlx.core as mx
from model import Bert, load_model

model, tokenizer = load_model(
    "bert-base-uncased",
    "weights/bert-base-uncased.npz")

batch = ["This is an example of BERT working on MLX."]
tokens = tokenizer(batch, return_tensors="np", padding=True)
tokens = {key: mx.array(v) for key, v in tokens.items()}

output, pooled = model(**tokens)
```

The `output` contains a `Batch x Tokens x Dims` tensor, representing a vector
for every input token. If you want to train anything at the **token level**,
use this.
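
For example, a token-level head is just a projection applied to each position
of `output`. Here is a minimal sketch; the linear head, `num_tags`, and the
768 hidden size of `bert-base-uncased` are illustrative assumptions, not part
of this repo:

```python
import mlx.nn as nn

num_tags = 9  # hypothetical tag-set size, e.g. for NER
head = nn.Linear(768, num_tags)  # 768 is the hidden size of bert-base-uncased

# `output` comes from the snippet above and has shape Batch x Tokens x Dims;
# the head maps every token vector to per-token logits.
token_logits = head(output)  # Batch x Tokens x num_tags
```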
					
						

The `pooled` contains a `Batch x Dims` tensor, which is the pooled
representation for each input. If you want to train a **classification**
model, use this.
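
For instance, a sentence classifier can be a single linear layer over
`pooled` (again a minimal sketch; `num_classes` and the layer itself are
illustrative assumptions):

```python
import mlx.nn as nn

num_classes = 2  # hypothetical number of classes
classifier = nn.Linear(768, num_classes)

# `pooled` comes from the snippet above and has shape Batch x Dims.
logits = classifier(pooled)  # Batch x num_classes
```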
					
						

## Test

You can check that the output of the default model (`bert-base-uncased`)
matches the Hugging Face version with:

```
python test.py
```