Some fixes / cleanup for BERT example (#269)

* some fixes/cleaning for bert + test

* nit
This commit is contained in:
Awni Hannun 2024-01-09 08:44:51 -08:00 committed by GitHub
parent 6759dfddf1
commit bbd7172eef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 77 additions and 117 deletions

View File

@ -2,9 +2,15 @@
An implementation of BERT [(Devlin, et al., 2019)](https://aclanthology.org/N19-1423/) within MLX.
## Downloading and Converting Weights
## Setup
The `convert.py` script relies on `transformers` to download the weights, and exports them as a single `.npz` file.
Install the requirements:
```
pip install -r requirements.txt
```
Then convert the weights with:
```
python convert.py \
@ -30,49 +36,20 @@ tokens = {key: mx.array(v) for key, v in tokens.items()}
output, pooled = model(**tokens)
```
The `output` contains a `Batch x Tokens x Dims` tensor, representing a vector for every input token.
If you want to train anything at a **token-level**, you'll want to use this.
The `output` contains a `Batch x Tokens x Dims` tensor, representing a vector
for every input token. If you want to train anything at a **token-level**,
you'll want to use this.
The `pooled` contains a `Batch x Dims` tensor, which is the pooled representation for each input.
If you want to train a **classification** model, you'll want to use this.
The `pooled` contains a `Batch x Dims` tensor, which is the pooled
representation for each input. If you want to train a **classification**
model, you'll want to use this.
## Comparison with 🤗 `transformers` Implementation
In order to run the model, and have it forward inference on a batch of examples:
## Test
You can check the output for the default model (`bert-base-uncased`) matches the
Hugging Face version with:
```sh
python model.py \
--bert-model bert-base-uncased \
--mlx-model weights/bert-base-uncased.npz
```
Which will show the following outputs:
```
MLX BERT:
[[[-0.52508914 -0.1993871 -0.28210318 ... -0.61125606 0.19114694
0.8227601 ]
[-0.8783862 -0.37107834 -0.52238125 ... -0.5067165 1.0847603
0.31066895]
[-0.70010054 -0.5424497 -0.26593682 ... -0.2688697 0.38338926
0.6557663 ]
...
```
They can be compared against the 🤗 implementation with:
```sh
python hf_model.py \
--bert-model bert-base-uncased
```
Which will show:
```
HF BERT:
[[[-0.52508944 -0.1993877 -0.28210333 ... -0.6112575 0.19114678
0.8227603 ]
[-0.878387 -0.371079 -0.522381 ... -0.50671494 1.0847601
0.31066933]
[-0.7001008 -0.5424504 -0.26593733 ... -0.26887015 0.38339025
0.65576553]
...
python test.py
```

View File

@ -1,43 +0,0 @@
import argparse
from transformers import AutoModel, AutoTokenizer
def run(bert_model: str):
batch = [
"This is an example of BERT working on MLX.",
"A second string",
"This is another string.",
]
tokenizer = AutoTokenizer.from_pretrained(bert_model)
torch_model = AutoModel.from_pretrained(bert_model)
torch_tokens = tokenizer(batch, return_tensors="pt", padding=True)
torch_forward = torch_model(**torch_tokens)
torch_output = torch_forward.last_hidden_state.detach().numpy()
torch_pooled = torch_forward.pooler_output.detach().numpy()
print("\n HF BERT:")
print(torch_output)
print("\n\n HF Pooled:")
print(torch_pooled[0, :20])
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Run the BERT model using Hugging Face Transformers."
)
parser.add_argument(
"--bert-model",
choices=[
"bert-base-uncased",
"bert-base-cased",
"bert-large-uncased",
"bert-large-cased",
],
default="bert-base-uncased",
help="The huggingface name of the BERT model to save.",
)
args = parser.parse_args()
run(args.bert_model)

View File

@ -1,6 +1,7 @@
import argparse
from dataclasses import dataclass
from typing import Optional
from pathlib import Path
from typing import List, Optional
import mlx.core as mx
import mlx.nn as nn
@ -12,7 +13,7 @@ from transformers import BertTokenizer
@dataclass
class ModelArgs:
intermediate_size: int = 768
dim: int = 768
num_attention_heads: int = 12
num_hidden_layers: int = 12
vocab_size: int = 30522
@ -26,10 +27,10 @@ model_configs = {
"bert-base-uncased": ModelArgs(),
"bert-base-cased": ModelArgs(),
"bert-large-uncased": ModelArgs(
intermediate_size=1024, num_attention_heads=16, num_hidden_layers=24
dim=1024, num_attention_heads=16, num_hidden_layers=24
),
"bert-large-cased": ModelArgs(
intermediate_size=1024, num_attention_heads=16, num_hidden_layers=24
dim=1024, num_attention_heads=16, num_hidden_layers=24
),
}
@ -86,12 +87,12 @@ class TransformerEncoder(nn.Module):
class BertEmbeddings(nn.Module):
def __init__(self, config: ModelArgs):
self.word_embeddings = nn.Embedding(config.vocab_size, config.intermediate_size)
self.token_type_embeddings = nn.Embedding(2, config.intermediate_size)
self.word_embeddings = nn.Embedding(config.vocab_size, config.dim)
self.token_type_embeddings = nn.Embedding(2, config.dim)
self.position_embeddings = nn.Embedding(
config.max_position_embeddings, config.intermediate_size
config.max_position_embeddings, config.dim
)
self.norm = nn.LayerNorm(config.intermediate_size, eps=config.layer_norm_eps)
self.norm = nn.LayerNorm(config.dim, eps=config.layer_norm_eps)
def __call__(self, input_ids: mx.array, token_type_ids: mx.array) -> mx.array:
words = self.word_embeddings(input_ids)
@ -109,10 +110,10 @@ class Bert(nn.Module):
self.embeddings = BertEmbeddings(config)
self.encoder = TransformerEncoder(
num_layers=config.num_hidden_layers,
dims=config.intermediate_size,
dims=config.dim,
num_heads=config.num_attention_heads,
)
self.pooler = nn.Linear(config.intermediate_size, config.vocab_size)
self.pooler = nn.Linear(config.dim, config.dim)
def __call__(
self,
@ -132,39 +133,25 @@ class Bert(nn.Module):
def load_model(bert_model: str, weights_path: str) -> tuple[Bert, BertTokenizer]:
# load the weights npz
weights = mx.load(weights_path)
weights = tree_unflatten(list(weights.items()))
if not Path(weights_path).exists():
raise ValueError(f"No model weights found in {weights_path}")
# create and update the model
model = Bert(model_configs[bert_model])
model.update(weights)
model.load_weights(weights_path)
tokenizer = BertTokenizer.from_pretrained(bert_model)
return model, tokenizer
def run(bert_model: str, mlx_model: str):
def run(bert_model: str, mlx_model: str, batch: List[str]):
model, tokenizer = load_model(bert_model, mlx_model)
batch = [
"This is an example of BERT working on MLX.",
"A second string",
"This is another string.",
]
tokens = tokenizer(batch, return_tensors="np", padding=True)
tokens = {key: mx.array(v) for key, v in tokens.items()}
mlx_output, mlx_pooled = model(**tokens)
mlx_output = numpy.array(mlx_output)
mlx_pooled = numpy.array(mlx_pooled)
print("MLX BERT:")
print(mlx_output)
print("\n\nMLX Pooled:")
print(mlx_pooled[0, :20])
return model(**tokens)
if __name__ == "__main__":
@ -181,6 +168,11 @@ if __name__ == "__main__":
default="weights/bert-base-uncased.npz",
help="The path of the stored MLX BERT weights (npz file).",
)
parser.add_argument(
"--text",
type=str,
default="This is an example of BERT working in MLX",
help="The text to generate embeddings for.",
)
args = parser.parse_args()
run(args.bert_model, args.mlx_model)
run(args.bert_model, args.mlx_model, args.text)

34
bert/test.py Normal file
View File

@ -0,0 +1,34 @@
from typing import List
import model
import numpy as np
from transformers import AutoModel, AutoTokenizer
def run_torch(bert_model: str, batch: List[str]):
tokenizer = AutoTokenizer.from_pretrained(bert_model)
torch_model = AutoModel.from_pretrained(bert_model)
torch_tokens = tokenizer(batch, return_tensors="pt", padding=True)
torch_forward = torch_model(**torch_tokens)
torch_output = torch_forward.last_hidden_state.detach().numpy()
torch_pooled = torch_forward.pooler_output.detach().numpy()
return torch_output, torch_pooled
if __name__ == "__main__":
bert_model = "bert-base-uncased"
mlx_model = "weights/bert-base-uncased.npz"
batch = [
"This is an example of BERT working in MLX.",
"A second string",
"This is another string.",
]
torch_output, torch_pooled = run_torch(bert_model, batch)
mlx_output, mlx_pooled = model.run(bert_model, mlx_model, batch)
assert np.allclose(
torch_output, mlx_output, rtol=1e-4, atol=1e-5
), "Model output is different"
assert np.allclose(
torch_pooled, mlx_pooled, rtol=1e-4, atol=1e-5
), "Model pooled output is different"
print("Tests pass :)")