mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
Some fixes / cleanup for BERT example (#269)
* some fixes/cleaning for bert + test * nit
This commit is contained in:
parent
6759dfddf1
commit
bbd7172eef
@ -2,9 +2,15 @@
|
|||||||
|
|
||||||
An implementation of BERT [(Devlin, et al., 2019)](https://aclanthology.org/N19-1423/) within MLX.
|
An implementation of BERT [(Devlin, et al., 2019)](https://aclanthology.org/N19-1423/) within MLX.
|
||||||
|
|
||||||
## Downloading and Converting Weights
|
## Setup
|
||||||
|
|
||||||
The `convert.py` script relies on `transformers` to download the weights, and exports them as a single `.npz` file.
|
Install the requirements:
|
||||||
|
|
||||||
|
```
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
Then convert the weights with:
|
||||||
|
|
||||||
```
|
```
|
||||||
python convert.py \
|
python convert.py \
|
||||||
@ -30,49 +36,20 @@ tokens = {key: mx.array(v) for key, v in tokens.items()}
|
|||||||
output, pooled = model(**tokens)
|
output, pooled = model(**tokens)
|
||||||
```
|
```
|
||||||
|
|
||||||
The `output` contains a `Batch x Tokens x Dims` tensor, representing a vector for every input token.
|
The `output` contains a `Batch x Tokens x Dims` tensor, representing a vector
|
||||||
If you want to train anything at a **token-level**, you'll want to use this.
|
for every input token. If you want to train anything at a **token-level**,
|
||||||
|
you'll want to use this.
|
||||||
|
|
||||||
The `pooled` contains a `Batch x Dims` tensor, which is the pooled representation for each input.
|
The `pooled` contains a `Batch x Dims` tensor, which is the pooled
|
||||||
If you want to train a **classification** model, you'll want to use this.
|
representation for each input. If you want to train a **classification**
|
||||||
|
model, you'll want to use this.
|
||||||
|
|
||||||
## Comparison with 🤗 `transformers` Implementation
|
|
||||||
|
|
||||||
In order to run the model, and have it forward inference on a batch of examples:
|
## Test
|
||||||
|
|
||||||
|
You can check the output for the default model (`bert-base-uncased`) matches the
|
||||||
|
Hugging Face version with:
|
||||||
|
|
||||||
```sh
|
|
||||||
python model.py \
|
|
||||||
--bert-model bert-base-uncased \
|
|
||||||
--mlx-model weights/bert-base-uncased.npz
|
|
||||||
```
|
```
|
||||||
|
python test.py
|
||||||
Which will show the following outputs:
|
|
||||||
```
|
|
||||||
MLX BERT:
|
|
||||||
[[[-0.52508914 -0.1993871 -0.28210318 ... -0.61125606 0.19114694
|
|
||||||
0.8227601 ]
|
|
||||||
[-0.8783862 -0.37107834 -0.52238125 ... -0.5067165 1.0847603
|
|
||||||
0.31066895]
|
|
||||||
[-0.70010054 -0.5424497 -0.26593682 ... -0.2688697 0.38338926
|
|
||||||
0.6557663 ]
|
|
||||||
...
|
|
||||||
```
|
|
||||||
|
|
||||||
They can be compared against the 🤗 implementation with:
|
|
||||||
|
|
||||||
```sh
|
|
||||||
python hf_model.py \
|
|
||||||
--bert-model bert-base-uncased
|
|
||||||
```
|
|
||||||
|
|
||||||
Which will show:
|
|
||||||
```
|
|
||||||
HF BERT:
|
|
||||||
[[[-0.52508944 -0.1993877 -0.28210333 ... -0.6112575 0.19114678
|
|
||||||
0.8227603 ]
|
|
||||||
[-0.878387 -0.371079 -0.522381 ... -0.50671494 1.0847601
|
|
||||||
0.31066933]
|
|
||||||
[-0.7001008 -0.5424504 -0.26593733 ... -0.26887015 0.38339025
|
|
||||||
0.65576553]
|
|
||||||
...
|
|
||||||
```
|
```
|
||||||
|
@ -1,43 +0,0 @@
|
|||||||
import argparse
|
|
||||||
|
|
||||||
from transformers import AutoModel, AutoTokenizer
|
|
||||||
|
|
||||||
|
|
||||||
def run(bert_model: str):
|
|
||||||
batch = [
|
|
||||||
"This is an example of BERT working on MLX.",
|
|
||||||
"A second string",
|
|
||||||
"This is another string.",
|
|
||||||
]
|
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(bert_model)
|
|
||||||
torch_model = AutoModel.from_pretrained(bert_model)
|
|
||||||
torch_tokens = tokenizer(batch, return_tensors="pt", padding=True)
|
|
||||||
torch_forward = torch_model(**torch_tokens)
|
|
||||||
torch_output = torch_forward.last_hidden_state.detach().numpy()
|
|
||||||
torch_pooled = torch_forward.pooler_output.detach().numpy()
|
|
||||||
|
|
||||||
print("\n HF BERT:")
|
|
||||||
print(torch_output)
|
|
||||||
print("\n\n HF Pooled:")
|
|
||||||
print(torch_pooled[0, :20])
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="Run the BERT model using Hugging Face Transformers."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--bert-model",
|
|
||||||
choices=[
|
|
||||||
"bert-base-uncased",
|
|
||||||
"bert-base-cased",
|
|
||||||
"bert-large-uncased",
|
|
||||||
"bert-large-cased",
|
|
||||||
],
|
|
||||||
default="bert-base-uncased",
|
|
||||||
help="The huggingface name of the BERT model to save.",
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
run(args.bert_model)
|
|
@ -1,6 +1,7 @@
|
|||||||
import argparse
|
import argparse
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional
|
from pathlib import Path
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
@ -12,7 +13,7 @@ from transformers import BertTokenizer
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelArgs:
|
class ModelArgs:
|
||||||
intermediate_size: int = 768
|
dim: int = 768
|
||||||
num_attention_heads: int = 12
|
num_attention_heads: int = 12
|
||||||
num_hidden_layers: int = 12
|
num_hidden_layers: int = 12
|
||||||
vocab_size: int = 30522
|
vocab_size: int = 30522
|
||||||
@ -26,10 +27,10 @@ model_configs = {
|
|||||||
"bert-base-uncased": ModelArgs(),
|
"bert-base-uncased": ModelArgs(),
|
||||||
"bert-base-cased": ModelArgs(),
|
"bert-base-cased": ModelArgs(),
|
||||||
"bert-large-uncased": ModelArgs(
|
"bert-large-uncased": ModelArgs(
|
||||||
intermediate_size=1024, num_attention_heads=16, num_hidden_layers=24
|
dim=1024, num_attention_heads=16, num_hidden_layers=24
|
||||||
),
|
),
|
||||||
"bert-large-cased": ModelArgs(
|
"bert-large-cased": ModelArgs(
|
||||||
intermediate_size=1024, num_attention_heads=16, num_hidden_layers=24
|
dim=1024, num_attention_heads=16, num_hidden_layers=24
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -86,12 +87,12 @@ class TransformerEncoder(nn.Module):
|
|||||||
|
|
||||||
class BertEmbeddings(nn.Module):
|
class BertEmbeddings(nn.Module):
|
||||||
def __init__(self, config: ModelArgs):
|
def __init__(self, config: ModelArgs):
|
||||||
self.word_embeddings = nn.Embedding(config.vocab_size, config.intermediate_size)
|
self.word_embeddings = nn.Embedding(config.vocab_size, config.dim)
|
||||||
self.token_type_embeddings = nn.Embedding(2, config.intermediate_size)
|
self.token_type_embeddings = nn.Embedding(2, config.dim)
|
||||||
self.position_embeddings = nn.Embedding(
|
self.position_embeddings = nn.Embedding(
|
||||||
config.max_position_embeddings, config.intermediate_size
|
config.max_position_embeddings, config.dim
|
||||||
)
|
)
|
||||||
self.norm = nn.LayerNorm(config.intermediate_size, eps=config.layer_norm_eps)
|
self.norm = nn.LayerNorm(config.dim, eps=config.layer_norm_eps)
|
||||||
|
|
||||||
def __call__(self, input_ids: mx.array, token_type_ids: mx.array) -> mx.array:
|
def __call__(self, input_ids: mx.array, token_type_ids: mx.array) -> mx.array:
|
||||||
words = self.word_embeddings(input_ids)
|
words = self.word_embeddings(input_ids)
|
||||||
@ -109,10 +110,10 @@ class Bert(nn.Module):
|
|||||||
self.embeddings = BertEmbeddings(config)
|
self.embeddings = BertEmbeddings(config)
|
||||||
self.encoder = TransformerEncoder(
|
self.encoder = TransformerEncoder(
|
||||||
num_layers=config.num_hidden_layers,
|
num_layers=config.num_hidden_layers,
|
||||||
dims=config.intermediate_size,
|
dims=config.dim,
|
||||||
num_heads=config.num_attention_heads,
|
num_heads=config.num_attention_heads,
|
||||||
)
|
)
|
||||||
self.pooler = nn.Linear(config.intermediate_size, config.vocab_size)
|
self.pooler = nn.Linear(config.dim, config.dim)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
@ -132,39 +133,25 @@ class Bert(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
def load_model(bert_model: str, weights_path: str) -> tuple[Bert, BertTokenizer]:
|
def load_model(bert_model: str, weights_path: str) -> tuple[Bert, BertTokenizer]:
|
||||||
# load the weights npz
|
if not Path(weights_path).exists():
|
||||||
weights = mx.load(weights_path)
|
raise ValueError(f"No model weights found in {weights_path}")
|
||||||
weights = tree_unflatten(list(weights.items()))
|
|
||||||
# create and update the model
|
# create and update the model
|
||||||
model = Bert(model_configs[bert_model])
|
model = Bert(model_configs[bert_model])
|
||||||
model.update(weights)
|
model.load_weights(weights_path)
|
||||||
|
|
||||||
tokenizer = BertTokenizer.from_pretrained(bert_model)
|
tokenizer = BertTokenizer.from_pretrained(bert_model)
|
||||||
|
|
||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
|
|
||||||
|
|
||||||
def run(bert_model: str, mlx_model: str):
|
def run(bert_model: str, mlx_model: str, batch: List[str]):
|
||||||
model, tokenizer = load_model(bert_model, mlx_model)
|
model, tokenizer = load_model(bert_model, mlx_model)
|
||||||
|
|
||||||
batch = [
|
|
||||||
"This is an example of BERT working on MLX.",
|
|
||||||
"A second string",
|
|
||||||
"This is another string.",
|
|
||||||
]
|
|
||||||
|
|
||||||
tokens = tokenizer(batch, return_tensors="np", padding=True)
|
tokens = tokenizer(batch, return_tensors="np", padding=True)
|
||||||
tokens = {key: mx.array(v) for key, v in tokens.items()}
|
tokens = {key: mx.array(v) for key, v in tokens.items()}
|
||||||
|
|
||||||
mlx_output, mlx_pooled = model(**tokens)
|
return model(**tokens)
|
||||||
mlx_output = numpy.array(mlx_output)
|
|
||||||
mlx_pooled = numpy.array(mlx_pooled)
|
|
||||||
|
|
||||||
print("MLX BERT:")
|
|
||||||
print(mlx_output)
|
|
||||||
|
|
||||||
print("\n\nMLX Pooled:")
|
|
||||||
print(mlx_pooled[0, :20])
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@ -181,6 +168,11 @@ if __name__ == "__main__":
|
|||||||
default="weights/bert-base-uncased.npz",
|
default="weights/bert-base-uncased.npz",
|
||||||
help="The path of the stored MLX BERT weights (npz file).",
|
help="The path of the stored MLX BERT weights (npz file).",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--text",
|
||||||
|
type=str,
|
||||||
|
default="This is an example of BERT working in MLX",
|
||||||
|
help="The text to generate embeddings for.",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
run(args.bert_model, args.mlx_model, args.text)
|
||||||
run(args.bert_model, args.mlx_model)
|
|
||||||
|
34
bert/test.py
Normal file
34
bert/test.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
|
import model
|
||||||
|
import numpy as np
|
||||||
|
from transformers import AutoModel, AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def run_torch(bert_model: str, batch: List[str]):
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(bert_model)
|
||||||
|
torch_model = AutoModel.from_pretrained(bert_model)
|
||||||
|
torch_tokens = tokenizer(batch, return_tensors="pt", padding=True)
|
||||||
|
torch_forward = torch_model(**torch_tokens)
|
||||||
|
torch_output = torch_forward.last_hidden_state.detach().numpy()
|
||||||
|
torch_pooled = torch_forward.pooler_output.detach().numpy()
|
||||||
|
return torch_output, torch_pooled
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
bert_model = "bert-base-uncased"
|
||||||
|
mlx_model = "weights/bert-base-uncased.npz"
|
||||||
|
batch = [
|
||||||
|
"This is an example of BERT working in MLX.",
|
||||||
|
"A second string",
|
||||||
|
"This is another string.",
|
||||||
|
]
|
||||||
|
torch_output, torch_pooled = run_torch(bert_model, batch)
|
||||||
|
mlx_output, mlx_pooled = model.run(bert_model, mlx_model, batch)
|
||||||
|
assert np.allclose(
|
||||||
|
torch_output, mlx_output, rtol=1e-4, atol=1e-5
|
||||||
|
), "Model output is different"
|
||||||
|
assert np.allclose(
|
||||||
|
torch_pooled, mlx_pooled, rtol=1e-4, atol=1e-5
|
||||||
|
), "Model pooled output is different"
|
||||||
|
print("Tests pass :)")
|
Loading…
Reference in New Issue
Block a user