Some fixes / cleanup for BERT example (#269)

* some fixes/cleaning for bert + test

* nit
This commit is contained in:
Awni Hannun
2024-01-09 08:44:51 -08:00
committed by GitHub
parent 6759dfddf1
commit bbd7172eef
4 changed files with 77 additions and 117 deletions

View File

@@ -1,6 +1,7 @@
import argparse
from dataclasses import dataclass
from typing import Optional
from pathlib import Path
from typing import List, Optional
import mlx.core as mx
import mlx.nn as nn
@@ -12,7 +13,7 @@ from transformers import BertTokenizer
@dataclass
class ModelArgs:
intermediate_size: int = 768
dim: int = 768
num_attention_heads: int = 12
num_hidden_layers: int = 12
vocab_size: int = 30522
@@ -26,10 +27,10 @@ model_configs = {
"bert-base-uncased": ModelArgs(),
"bert-base-cased": ModelArgs(),
"bert-large-uncased": ModelArgs(
intermediate_size=1024, num_attention_heads=16, num_hidden_layers=24
dim=1024, num_attention_heads=16, num_hidden_layers=24
),
"bert-large-cased": ModelArgs(
intermediate_size=1024, num_attention_heads=16, num_hidden_layers=24
dim=1024, num_attention_heads=16, num_hidden_layers=24
),
}
@@ -86,12 +87,12 @@ class TransformerEncoder(nn.Module):
class BertEmbeddings(nn.Module):
def __init__(self, config: ModelArgs):
self.word_embeddings = nn.Embedding(config.vocab_size, config.intermediate_size)
self.token_type_embeddings = nn.Embedding(2, config.intermediate_size)
self.word_embeddings = nn.Embedding(config.vocab_size, config.dim)
self.token_type_embeddings = nn.Embedding(2, config.dim)
self.position_embeddings = nn.Embedding(
config.max_position_embeddings, config.intermediate_size
config.max_position_embeddings, config.dim
)
self.norm = nn.LayerNorm(config.intermediate_size, eps=config.layer_norm_eps)
self.norm = nn.LayerNorm(config.dim, eps=config.layer_norm_eps)
def __call__(self, input_ids: mx.array, token_type_ids: mx.array) -> mx.array:
words = self.word_embeddings(input_ids)
@@ -109,10 +110,10 @@ class Bert(nn.Module):
self.embeddings = BertEmbeddings(config)
self.encoder = TransformerEncoder(
num_layers=config.num_hidden_layers,
dims=config.intermediate_size,
dims=config.dim,
num_heads=config.num_attention_heads,
)
self.pooler = nn.Linear(config.intermediate_size, config.vocab_size)
self.pooler = nn.Linear(config.dim, config.dim)
def __call__(
self,
@@ -132,39 +133,25 @@ class Bert(nn.Module):
def load_model(bert_model: str, weights_path: str) -> tuple[Bert, BertTokenizer]:
# load the weights npz
weights = mx.load(weights_path)
weights = tree_unflatten(list(weights.items()))
if not Path(weights_path).exists():
raise ValueError(f"No model weights found in {weights_path}")
# create and update the model
model = Bert(model_configs[bert_model])
model.update(weights)
model.load_weights(weights_path)
tokenizer = BertTokenizer.from_pretrained(bert_model)
return model, tokenizer
def run(bert_model: str, mlx_model: str):
def run(bert_model: str, mlx_model: str, batch: List[str]):
model, tokenizer = load_model(bert_model, mlx_model)
batch = [
"This is an example of BERT working on MLX.",
"A second string",
"This is another string.",
]
tokens = tokenizer(batch, return_tensors="np", padding=True)
tokens = {key: mx.array(v) for key, v in tokens.items()}
mlx_output, mlx_pooled = model(**tokens)
mlx_output = numpy.array(mlx_output)
mlx_pooled = numpy.array(mlx_pooled)
print("MLX BERT:")
print(mlx_output)
print("\n\nMLX Pooled:")
print(mlx_pooled[0, :20])
return model(**tokens)
if __name__ == "__main__":
@@ -181,6 +168,11 @@ if __name__ == "__main__":
default="weights/bert-base-uncased.npz",
help="The path of the stored MLX BERT weights (npz file).",
)
parser.add_argument(
"--text",
type=str,
default="This is an example of BERT working in MLX",
help="The text to generate embeddings for.",
)
args = parser.parse_args()
run(args.bert_model, args.mlx_model)
run(args.bert_model, args.mlx_model, args.text)