import argparse
from pathlib import Path
from typing import List, Optional, Tuple

import mlx.core as mx
import mlx.nn as nn
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase


class TransformerEncoderLayer(nn.Module):
    """
    A transformer encoder layer with post-layer-normalization, as in the
    original BERT.
    """

    def __init__(
        self,
        dims: int,
        num_heads: int,
        mlp_dims: Optional[int] = None,
        layer_norm_eps: float = 1e-12,
    ):
        super().__init__()
        mlp_dims = mlp_dims or dims * 4
        self.attention = nn.MultiHeadAttention(dims, num_heads, bias=True)
        self.ln1 = nn.LayerNorm(dims, eps=layer_norm_eps)
        self.ln2 = nn.LayerNorm(dims, eps=layer_norm_eps)
        self.linear1 = nn.Linear(dims, mlp_dims)
        self.linear2 = nn.Linear(mlp_dims, dims)
        self.gelu = nn.GELU()

    def __call__(self, x, mask):
        # Self-attention, then a residual connection and post-norm
        attention_out = self.attention(x, x, x, mask)
        add_and_norm = self.ln1(x + attention_out)

        # Position-wise feed-forward with GELU, then residual and post-norm
        ff = self.linear1(add_and_norm)
        ff_gelu = self.gelu(ff)
        ff_out = self.linear2(ff_gelu)
        x = self.ln2(ff_out + add_and_norm)

        return x
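
    # A quick shape check for this layer (an illustrative sketch, not part of
    # the original source; random weights, no mask):
    #
    #   layer = TransformerEncoderLayer(dims=768, num_heads=12)
    #   x = mx.zeros((1, 8, 768))
    #   layer(x, mask=None).shape  # -> (1, 8, 768)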


class TransformerEncoder(nn.Module):
    """A stack of post-norm transformer encoder layers applied in sequence."""

    def __init__(
        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
    ):
        super().__init__()
        self.layers = [
            TransformerEncoderLayer(dims, num_heads, mlp_dims)
            for _ in range(num_layers)
        ]

    def __call__(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)

        return x


class BertEmbeddings(nn.Module):
    """Word, position, and token-type embeddings, summed and layer-normalized."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(
            config.type_vocab_size, config.hidden_size
        )
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size
        )
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def __call__(
        self, input_ids: mx.array, token_type_ids: Optional[mx.array] = None
    ) -> mx.array:
        words = self.word_embeddings(input_ids)
        position = self.position_embeddings(
            mx.broadcast_to(mx.arange(input_ids.shape[1]), input_ids.shape)
        )
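        # Illustrative (not part of the original source): the position indices
        # are identical for every row of the batch, e.g.
        #   mx.broadcast_to(mx.arange(3), (2, 3)) -> [[0, 1, 2], [0, 1, 2]]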

        if token_type_ids is None:
            # If token_type_ids is not provided, default to zeros
            token_type_ids = mx.zeros_like(input_ids)

        token_types = self.token_type_embeddings(token_type_ids)

        embeddings = position + words + token_types
        return self.norm(embeddings)


class Bert(nn.Module):
    """The BERT encoder stack with a tanh-activated [CLS] pooler."""

    def __init__(self, config):
        super().__init__()
        self.embeddings = BertEmbeddings(config)
        self.encoder = TransformerEncoder(
            num_layers=config.num_hidden_layers,
            dims=config.hidden_size,
            num_heads=config.num_attention_heads,
            mlp_dims=config.intermediate_size,
        )
        self.pooler = nn.Linear(config.hidden_size, config.hidden_size)

    def __call__(
        self,
        input_ids: mx.array,
        token_type_ids: Optional[mx.array] = None,
        attention_mask: Optional[mx.array] = None,
    ) -> Tuple[mx.array, mx.array]:
        x = self.embeddings(input_ids, token_type_ids)

        if attention_mask is not None:
            # Convert 1's to 0's and 0's to -inf, then add singleton axes so
            # the mask broadcasts over the attention scores
            attention_mask = mx.log(attention_mask)
            attention_mask = mx.expand_dims(attention_mask, (1, 2))

        y = self.encoder(x, attention_mask)
        # Per-token hidden states, plus the tanh-pooled [CLS] representation
        return y, mx.tanh(self.pooler(y[:, 0]))
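
    # How the additive mask works (illustrative values, not from the original
    # source): for a padding mask of [1, 1, 0], mx.log gives [0, 0, -inf], so
    # padded positions add -inf to the attention scores and get zero weight
    # after the softmax:
    #
    #   mx.log(mx.array([[1.0, 1.0, 0.0]]))  # -> [[0.0, 0.0, -inf]]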


def load_model(
    bert_model: str, weights_path: str
) -> Tuple[Bert, PreTrainedTokenizerBase]:
    if not Path(weights_path).exists():
        raise ValueError(f"No model weights found in {weights_path}")

    config = AutoConfig.from_pretrained(bert_model)

    # Create the model and update it with the stored MLX weights
    model = Bert(config)
    model.load_weights(weights_path)

    tokenizer = AutoTokenizer.from_pretrained(bert_model)

    return model, tokenizer
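
# Example (an illustrative sketch; assumes the MLX weights were already
# converted and saved at the given .npz path):
#
#   model, tokenizer = load_model(
#       "bert-base-uncased", "weights/bert-base-uncased.npz"
#   )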


def run(bert_model: str, mlx_model: str, batch: List[str]):
    model, tokenizer = load_model(bert_model, mlx_model)

    # Tokenize the batch, then convert the NumPy arrays to MLX arrays
    tokens = tokenizer(batch, return_tensors="np", padding=True)
    tokens = {key: mx.array(v) for key, v in tokens.items()}

    return model(**tokens)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the BERT model using MLX.")
    parser.add_argument(
        "--bert-model",
        type=str,
        default="bert-base-uncased",
        help="The Hugging Face name of the BERT model to run.",
    )
    parser.add_argument(
        "--mlx-model",
        type=str,
        default="weights/bert-base-uncased.npz",
        help="The path of the stored MLX BERT weights (npz file).",
    )
    parser.add_argument(
        "--text",
        type=str,
        default="This is an example of BERT working in MLX",
        help="The text to generate embeddings for.",
    )
    args = parser.parse_args()
    run(args.bert_model, args.mlx_model, [args.text])
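
    # Example invocation (an illustrative sketch; assumes this file is saved
    # as model.py and the converted weights exist at the default path):
    #
    #   python model.py --mlx-model weights/bert-base-uncased.npz \
    #       --text "This is an example of BERT working in MLX"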