Some fixes / cleanup for BERT example (#269)

* some fixes/cleaning for bert + test

* nit
This commit is contained in:
Awni Hannun 2024-01-09 08:44:51 -08:00 committed by GitHub
parent 6759dfddf1
commit bbd7172eef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 77 additions and 117 deletions

View File

@ -2,9 +2,15 @@
An implementation of BERT [(Devlin, et al., 2019)](https://aclanthology.org/N19-1423/) within MLX. An implementation of BERT [(Devlin, et al., 2019)](https://aclanthology.org/N19-1423/) within MLX.
## Downloading and Converting Weights ## Setup
The `convert.py` script relies on `transformers` to download the weights, and exports them as a single `.npz` file. Install the requirements:
```
pip install -r requirements.txt
```
Then convert the weights with:
``` ```
python convert.py \ python convert.py \
@ -30,49 +36,20 @@ tokens = {key: mx.array(v) for key, v in tokens.items()}
output, pooled = model(**tokens) output, pooled = model(**tokens)
``` ```
The `output` contains a `Batch x Tokens x Dims` tensor, representing a vector for every input token. The `output` contains a `Batch x Tokens x Dims` tensor, representing a vector
If you want to train anything at a **token-level**, you'll want to use this. for every input token. If you want to train anything at a **token-level**,
you'll want to use this.
The `pooled` contains a `Batch x Dims` tensor, which is the pooled representation for each input. The `pooled` contains a `Batch x Dims` tensor, which is the pooled
If you want to train a **classification** model, you'll want to use this. representation for each input. If you want to train a **classification**
model, you'll want to use this.
## Comparison with 🤗 `transformers` Implementation
In order to run the model, and have it forward inference on a batch of examples: ## Test
You can check the output for the default model (`bert-base-uncased`) matches the
Hugging Face version with:
```sh
python model.py \
--bert-model bert-base-uncased \
--mlx-model weights/bert-base-uncased.npz
``` ```
python test.py
Which will show the following outputs:
```
MLX BERT:
[[[-0.52508914 -0.1993871 -0.28210318 ... -0.61125606 0.19114694
0.8227601 ]
[-0.8783862 -0.37107834 -0.52238125 ... -0.5067165 1.0847603
0.31066895]
[-0.70010054 -0.5424497 -0.26593682 ... -0.2688697 0.38338926
0.6557663 ]
...
```
They can be compared against the 🤗 implementation with:
```sh
python hf_model.py \
--bert-model bert-base-uncased
```
Which will show:
```
HF BERT:
[[[-0.52508944 -0.1993877 -0.28210333 ... -0.6112575 0.19114678
0.8227603 ]
[-0.878387 -0.371079 -0.522381 ... -0.50671494 1.0847601
0.31066933]
[-0.7001008 -0.5424504 -0.26593733 ... -0.26887015 0.38339025
0.65576553]
...
``` ```

View File

@ -1,43 +0,0 @@
import argparse
from transformers import AutoModel, AutoTokenizer
def run(bert_model: str):
batch = [
"This is an example of BERT working on MLX.",
"A second string",
"This is another string.",
]
tokenizer = AutoTokenizer.from_pretrained(bert_model)
torch_model = AutoModel.from_pretrained(bert_model)
torch_tokens = tokenizer(batch, return_tensors="pt", padding=True)
torch_forward = torch_model(**torch_tokens)
torch_output = torch_forward.last_hidden_state.detach().numpy()
torch_pooled = torch_forward.pooler_output.detach().numpy()
print("\n HF BERT:")
print(torch_output)
print("\n\n HF Pooled:")
print(torch_pooled[0, :20])
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Run the BERT model using Hugging Face Transformers."
)
parser.add_argument(
"--bert-model",
choices=[
"bert-base-uncased",
"bert-base-cased",
"bert-large-uncased",
"bert-large-cased",
],
default="bert-base-uncased",
help="The huggingface name of the BERT model to save.",
)
args = parser.parse_args()
run(args.bert_model)

View File

@ -1,6 +1,7 @@
import argparse import argparse
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from pathlib import Path
from typing import List, Optional
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
@ -12,7 +13,7 @@ from transformers import BertTokenizer
@dataclass @dataclass
class ModelArgs: class ModelArgs:
intermediate_size: int = 768 dim: int = 768
num_attention_heads: int = 12 num_attention_heads: int = 12
num_hidden_layers: int = 12 num_hidden_layers: int = 12
vocab_size: int = 30522 vocab_size: int = 30522
@ -26,10 +27,10 @@ model_configs = {
"bert-base-uncased": ModelArgs(), "bert-base-uncased": ModelArgs(),
"bert-base-cased": ModelArgs(), "bert-base-cased": ModelArgs(),
"bert-large-uncased": ModelArgs( "bert-large-uncased": ModelArgs(
intermediate_size=1024, num_attention_heads=16, num_hidden_layers=24 dim=1024, num_attention_heads=16, num_hidden_layers=24
), ),
"bert-large-cased": ModelArgs( "bert-large-cased": ModelArgs(
intermediate_size=1024, num_attention_heads=16, num_hidden_layers=24 dim=1024, num_attention_heads=16, num_hidden_layers=24
), ),
} }
@ -86,12 +87,12 @@ class TransformerEncoder(nn.Module):
class BertEmbeddings(nn.Module): class BertEmbeddings(nn.Module):
def __init__(self, config: ModelArgs): def __init__(self, config: ModelArgs):
self.word_embeddings = nn.Embedding(config.vocab_size, config.intermediate_size) self.word_embeddings = nn.Embedding(config.vocab_size, config.dim)
self.token_type_embeddings = nn.Embedding(2, config.intermediate_size) self.token_type_embeddings = nn.Embedding(2, config.dim)
self.position_embeddings = nn.Embedding( self.position_embeddings = nn.Embedding(
config.max_position_embeddings, config.intermediate_size config.max_position_embeddings, config.dim
) )
self.norm = nn.LayerNorm(config.intermediate_size, eps=config.layer_norm_eps) self.norm = nn.LayerNorm(config.dim, eps=config.layer_norm_eps)
def __call__(self, input_ids: mx.array, token_type_ids: mx.array) -> mx.array: def __call__(self, input_ids: mx.array, token_type_ids: mx.array) -> mx.array:
words = self.word_embeddings(input_ids) words = self.word_embeddings(input_ids)
@ -109,10 +110,10 @@ class Bert(nn.Module):
self.embeddings = BertEmbeddings(config) self.embeddings = BertEmbeddings(config)
self.encoder = TransformerEncoder( self.encoder = TransformerEncoder(
num_layers=config.num_hidden_layers, num_layers=config.num_hidden_layers,
dims=config.intermediate_size, dims=config.dim,
num_heads=config.num_attention_heads, num_heads=config.num_attention_heads,
) )
self.pooler = nn.Linear(config.intermediate_size, config.vocab_size) self.pooler = nn.Linear(config.dim, config.dim)
def __call__( def __call__(
self, self,
@ -132,39 +133,25 @@ class Bert(nn.Module):
def load_model(bert_model: str, weights_path: str) -> tuple[Bert, BertTokenizer]: def load_model(bert_model: str, weights_path: str) -> tuple[Bert, BertTokenizer]:
# load the weights npz if not Path(weights_path).exists():
weights = mx.load(weights_path) raise ValueError(f"No model weights found in {weights_path}")
weights = tree_unflatten(list(weights.items()))
# create and update the model # create and update the model
model = Bert(model_configs[bert_model]) model = Bert(model_configs[bert_model])
model.update(weights) model.load_weights(weights_path)
tokenizer = BertTokenizer.from_pretrained(bert_model) tokenizer = BertTokenizer.from_pretrained(bert_model)
return model, tokenizer return model, tokenizer
def run(bert_model: str, mlx_model: str): def run(bert_model: str, mlx_model: str, batch: List[str]):
model, tokenizer = load_model(bert_model, mlx_model) model, tokenizer = load_model(bert_model, mlx_model)
batch = [
"This is an example of BERT working on MLX.",
"A second string",
"This is another string.",
]
tokens = tokenizer(batch, return_tensors="np", padding=True) tokens = tokenizer(batch, return_tensors="np", padding=True)
tokens = {key: mx.array(v) for key, v in tokens.items()} tokens = {key: mx.array(v) for key, v in tokens.items()}
mlx_output, mlx_pooled = model(**tokens) return model(**tokens)
mlx_output = numpy.array(mlx_output)
mlx_pooled = numpy.array(mlx_pooled)
print("MLX BERT:")
print(mlx_output)
print("\n\nMLX Pooled:")
print(mlx_pooled[0, :20])
if __name__ == "__main__": if __name__ == "__main__":
@ -181,6 +168,11 @@ if __name__ == "__main__":
default="weights/bert-base-uncased.npz", default="weights/bert-base-uncased.npz",
help="The path of the stored MLX BERT weights (npz file).", help="The path of the stored MLX BERT weights (npz file).",
) )
parser.add_argument(
"--text",
type=str,
default="This is an example of BERT working in MLX",
help="The text to generate embeddings for.",
)
args = parser.parse_args() args = parser.parse_args()
run(args.bert_model, args.mlx_model, args.text)
run(args.bert_model, args.mlx_model)

34
bert/test.py Normal file
View File

@ -0,0 +1,34 @@
from typing import List
import model
import numpy as np
from transformers import AutoModel, AutoTokenizer
def run_torch(bert_model: str, batch: List[str]):
tokenizer = AutoTokenizer.from_pretrained(bert_model)
torch_model = AutoModel.from_pretrained(bert_model)
torch_tokens = tokenizer(batch, return_tensors="pt", padding=True)
torch_forward = torch_model(**torch_tokens)
torch_output = torch_forward.last_hidden_state.detach().numpy()
torch_pooled = torch_forward.pooler_output.detach().numpy()
return torch_output, torch_pooled
if __name__ == "__main__":
bert_model = "bert-base-uncased"
mlx_model = "weights/bert-base-uncased.npz"
batch = [
"This is an example of BERT working in MLX.",
"A second string",
"This is another string.",
]
torch_output, torch_pooled = run_torch(bert_model, batch)
mlx_output, mlx_pooled = model.run(bert_model, mlx_model, batch)
assert np.allclose(
torch_output, mlx_output, rtol=1e-4, atol=1e-5
), "Model output is different"
assert np.allclose(
torch_pooled, mlx_pooled, rtol=1e-4, atol=1e-5
), "Model pooled output is different"
print("Tests pass :)")