From 4680ef44134509ce85433dd2d3152a5fd0fbfcf4 Mon Sep 17 00:00:00 2001
From: yzimmermann <92678727+yzimmermann@users.noreply.github.com>
Date: Wed, 20 Mar 2024 01:21:33 +0100
Subject: [PATCH] Enable more BERT models (#580)

* Update convert.py

* Update model.py

* Update test.py

* Update model.py

* Update convert.py

* Add files via upload

* Update convert.py

* format

* nit

* nit

---------

Co-authored-by: Awni Hannun
---
 bert/README.md  |  8 +++---
 bert/convert.py | 13 +++-------
 bert/model.py   | 68 +++++++++++++++++++++----------------------------
 bert/test.py    | 55 +++++++++++++++++++++++++++------------
 4 files changed, 76 insertions(+), 68 deletions(-)

diff --git a/bert/README.md b/bert/README.md
index e03df7d0..cf4d8db4 100644
--- a/bert/README.md
+++ b/bert/README.md
@@ -1,6 +1,6 @@
 # BERT
 
-An implementation of BERT [(Devlin, et al., 2019)](https://aclanthology.org/N19-1423/) within MLX.
+An implementation of BERT [(Devlin, et al., 2019)](https://aclanthology.org/N19-1423/) in MLX.
 
 ## Setup
 
@@ -38,12 +38,12 @@ output, pooled = model(**tokens)
 ```
 
 The `output` contains a `Batch x Tokens x Dims` tensor, representing a vector
-for every input token. If you want to train anything at a **token-level**,
-you'll want to use this.
+for every input token. If you want to train anything at the **token-level**,
+use this.
 
 The `pooled` contains a `Batch x Dims` tensor, which is the pooled
 representation for each input. If you want to train a **classification**
-model, you'll want to use this.
+model, use this.
 
 ## Test
 
diff --git a/bert/convert.py b/bert/convert.py
index 4cad3cc6..63b448b4 100644
--- a/bert/convert.py
+++ b/bert/convert.py
@@ -1,7 +1,7 @@
 import argparse
 
 import numpy
-from transformers import BertModel
+from transformers import AutoModel
 
 
 def replace_key(key: str) -> str:
@@ -20,7 +20,7 @@ def replace_key(key: str) -> str:
 
 
 def convert(bert_model: str, mlx_model: str) -> None:
-    model = BertModel.from_pretrained(bert_model)
+    model = AutoModel.from_pretrained(bert_model)
     # save the tensors
     tensors = {
         replace_key(key): tensor.numpy() for key, tensor in model.state_dict().items()
@@ -32,14 +32,9 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Convert BERT weights to MLX.")
     parser.add_argument(
         "--bert-model",
-        choices=[
-            "bert-base-uncased",
-            "bert-base-cased",
-            "bert-large-uncased",
-            "bert-large-cased",
-        ],
+        type=str,
         default="bert-base-uncased",
-        help="The huggingface name of the BERT model to save.",
+        help="The huggingface name of the BERT model to save. Any BERT-like model can be specified.",
     )
     parser.add_argument(
         "--mlx-model",
diff --git a/bert/model.py b/bert/model.py
index 381dd5e2..2580c365 100644
--- a/bert/model.py
+++ b/bert/model.py
@@ -8,31 +8,7 @@ import mlx.nn as nn
 import numpy
 import numpy as np
 from mlx.utils import tree_unflatten
-from transformers import BertTokenizer
-
-
-@dataclass
-class ModelArgs:
-    dim: int = 768
-    num_attention_heads: int = 12
-    num_hidden_layers: int = 12
-    vocab_size: int = 30522
-    attention_probs_dropout_prob: float = 0.1
-    hidden_dropout_prob: float = 0.1
-    layer_norm_eps: float = 1e-12
-    max_position_embeddings: int = 512
-
-
-model_configs = {
-    "bert-base-uncased": ModelArgs(),
-    "bert-base-cased": ModelArgs(),
-    "bert-large-uncased": ModelArgs(
-        dim=1024, num_attention_heads=16, num_hidden_layers=24
-    ),
-    "bert-large-cased": ModelArgs(
-        dim=1024, num_attention_heads=16, num_hidden_layers=24
-    ),
-}
+from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase
 
 
 class TransformerEncoderLayer(nn.Module):
@@ -86,20 +62,29 @@ class TransformerEncoder(nn.Module):
 
 
 class BertEmbeddings(nn.Module):
-    def __init__(self, config: ModelArgs):
+    def __init__(self, config):
         super().__init__()
-        self.word_embeddings = nn.Embedding(config.vocab_size, config.dim)
-        self.token_type_embeddings = nn.Embedding(2, config.dim)
-        self.position_embeddings = nn.Embedding(
-            config.max_position_embeddings, config.dim
+        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
+        self.token_type_embeddings = nn.Embedding(
+            config.type_vocab_size, config.hidden_size
         )
-        self.norm = nn.LayerNorm(config.dim, eps=config.layer_norm_eps)
+        self.position_embeddings = nn.Embedding(
+            config.max_position_embeddings, config.hidden_size
+        )
+        self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
-    def __call__(self, input_ids: mx.array, token_type_ids: mx.array) -> mx.array:
+    def __call__(
+        self, input_ids: mx.array, token_type_ids: mx.array = None
+    ) -> mx.array:
         words = self.word_embeddings(input_ids)
         position = self.position_embeddings(
             mx.broadcast_to(mx.arange(input_ids.shape[1]), input_ids.shape)
         )
+
+        if token_type_ids is None:
+            # If token_type_ids is not provided, default to zeros
+            token_type_ids = mx.zeros_like(input_ids)
+
         token_types = self.token_type_embeddings(token_type_ids)
 
         embeddings = position + words + token_types
@@ -107,20 +92,21 @@ class BertEmbeddings(nn.Module):
 
 
 class Bert(nn.Module):
-    def __init__(self, config: ModelArgs):
+    def __init__(self, config):
         super().__init__()
         self.embeddings = BertEmbeddings(config)
         self.encoder = TransformerEncoder(
             num_layers=config.num_hidden_layers,
-            dims=config.dim,
+            dims=config.hidden_size,
             num_heads=config.num_attention_heads,
+            mlp_dims=config.intermediate_size,
         )
-        self.pooler = nn.Linear(config.dim, config.dim)
+        self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
 
     def __call__(
         self,
         input_ids: mx.array,
-        token_type_ids: mx.array,
+        token_type_ids: mx.array = None,
         attention_mask: mx.array = None,
     ) -> Tuple[mx.array, mx.array]:
         x = self.embeddings(input_ids, token_type_ids)
@@ -134,15 +120,19 @@ class Bert(nn.Module):
         return y, mx.tanh(self.pooler(y[:, 0]))
 
 
-def load_model(bert_model: str, weights_path: str) -> Tuple[Bert, BertTokenizer]:
+def load_model(
+    bert_model: str, weights_path: str
+) -> Tuple[Bert, PreTrainedTokenizerBase]:
     if not Path(weights_path).exists():
         raise ValueError(f"No model weights found in {weights_path}")
 
+    config = AutoConfig.from_pretrained(bert_model)
+
     # create and update the model
-    model = Bert(model_configs[bert_model])
+    model = Bert(config)
     model.load_weights(weights_path)
 
-    tokenizer = BertTokenizer.from_pretrained(bert_model)
+    tokenizer = AutoTokenizer.from_pretrained(bert_model)
 
     return model, tokenizer
 
diff --git a/bert/test.py b/bert/test.py
index 089fc45f..e5462ba3 100644
--- a/bert/test.py
+++ b/bert/test.py
@@ -1,3 +1,4 @@
+import argparse
 from typing import List
 
 import model
@@ -16,19 +17,41 @@ def run_torch(bert_model: str, batch: List[str]):
 
 
 if __name__ == "__main__":
-    bert_model = "bert-base-uncased"
-    mlx_model = "weights/bert-base-uncased.npz"
-    batch = [
-        "This is an example of BERT working in MLX.",
-        "A second string",
-        "This is another string.",
-    ]
-    torch_output, torch_pooled = run_torch(bert_model, batch)
-    mlx_output, mlx_pooled = model.run(bert_model, mlx_model, batch)
-    assert np.allclose(
-        torch_output, mlx_output, rtol=1e-4, atol=1e-5
-    ), "Model output is different"
-    assert np.allclose(
-        torch_pooled, mlx_pooled, rtol=1e-4, atol=1e-5
-    ), "Model pooled output is different"
-    print("Tests pass :)")
+    parser = argparse.ArgumentParser(
+        description="Run a BERT-like model for a batch of text."
+    )
+    parser.add_argument(
+        "--bert-model",
+        type=str,
+        default="bert-base-uncased",
+        help="The model identifier for a BERT-like model from Hugging Face Transformers.",
+    )
+    parser.add_argument(
+        "--mlx-model",
+        type=str,
+        default="weights/bert-base-uncased.npz",
+        help="The path of the stored MLX BERT weights (npz file).",
+    )
+    parser.add_argument(
+        "--text",
+        nargs="+",
+        default=["This is an example of BERT working in MLX."],
+        help="A batch of texts to process. Multiple texts should be separated by spaces.",
+    )
+
+    args = parser.parse_args()
+
+    torch_output, torch_pooled = run_torch(args.bert_model, args.text)
+
+    mlx_output, mlx_pooled = model.run(args.bert_model, args.mlx_model, args.text)
+
+    if torch_pooled is not None and mlx_pooled is not None:
+        assert np.allclose(
+            torch_output, mlx_output, rtol=1e-4, atol=1e-5
+        ), "Model output is different"
+        assert np.allclose(
+            torch_pooled, mlx_pooled, rtol=1e-4, atol=1e-5
+        ), "Model pooled output is different"
+        print("Tests pass :)")
+    else:
+        print("Pooled outputs were not compared due to one or both being None.")
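
For reference, below is a minimal usage sketch (not part of the patch) of what `load_model` enables after this change: loading a non-default, BERT-like checkpoint. The model id `bert-base-multilingual-cased` and the weights path are illustrative placeholders only; the sketch assumes the MLX weights were first exported with `convert.py --bert-model bert-base-multilingual-cased --mlx-model weights/bert-base-multilingual-cased.npz`, mirroring the pattern already shown in `bert/README.md`.

```python
# Sketch only: use the updated load_model with an arbitrary BERT-like checkpoint.
# Assumes the npz weights were already produced by convert.py at the path below.
import mlx.core as mx

from model import load_model

model, tokenizer = load_model(
    "bert-base-multilingual-cased",              # illustrative HF model id
    "weights/bert-base-multilingual-cased.npz",  # illustrative weights path
)

batch = ["This is an example of BERT working in MLX."]
tokens = tokenizer(batch, return_tensors="np", padding=True)
tokens = {key: mx.array(v) for key, v in tokens.items()}

# output: Batch x Tokens x Dims, pooled: Batch x Dims (see the README)
output, pooled = model(**tokens)
print(output.shape, pooled.shape)
```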