Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-09-01 04:14:38 +08:00
Some fixes / cleanup for BERT example (#269)

* some fixes/cleaning for bert + test
* nit
@@ -1,6 +1,7 @@
 import argparse
 from dataclasses import dataclass
-from typing import Optional
+from pathlib import Path
+from typing import List, Optional
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -12,7 +13,7 @@ from transformers import BertTokenizer
 
 @dataclass
 class ModelArgs:
-    intermediate_size: int = 768
+    dim: int = 768
     num_attention_heads: int = 12
     num_hidden_layers: int = 12
     vocab_size: int = 30522
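Note: the rename matters because, in standard BERT terminology, `intermediate_size` is the feed-forward width (3072 for bert-base), not the 768-dimensional hidden size this field actually holds; `dim` is the accurate name. A minimal sketch of the distinction (the `ffn_dims` expression is illustrative, not a field of this config):

    config = ModelArgs()        # defaults match bert-base: dim=768
    hidden_dims = config.dim    # transformer hidden size: 768
    ffn_dims = 4 * config.dim   # what BERT papers call intermediate_size: 3072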
@@ -26,10 +27,10 @@ model_configs = {
     "bert-base-uncased": ModelArgs(),
     "bert-base-cased": ModelArgs(),
     "bert-large-uncased": ModelArgs(
-        intermediate_size=1024, num_attention_heads=16, num_hidden_layers=24
+        dim=1024, num_attention_heads=16, num_hidden_layers=24
     ),
     "bert-large-cased": ModelArgs(
-        intermediate_size=1024, num_attention_heads=16, num_hidden_layers=24
+        dim=1024, num_attention_heads=16, num_hidden_layers=24
     ),
 }
 
@@ -86,12 +87,12 @@ class TransformerEncoder(nn.Module):
 
 class BertEmbeddings(nn.Module):
     def __init__(self, config: ModelArgs):
-        self.word_embeddings = nn.Embedding(config.vocab_size, config.intermediate_size)
-        self.token_type_embeddings = nn.Embedding(2, config.intermediate_size)
+        self.word_embeddings = nn.Embedding(config.vocab_size, config.dim)
+        self.token_type_embeddings = nn.Embedding(2, config.dim)
         self.position_embeddings = nn.Embedding(
-            config.max_position_embeddings, config.intermediate_size
+            config.max_position_embeddings, config.dim
         )
-        self.norm = nn.LayerNorm(config.intermediate_size, eps=config.layer_norm_eps)
+        self.norm = nn.LayerNorm(config.dim, eps=config.layer_norm_eps)
 
     def __call__(self, input_ids: mx.array, token_type_ids: mx.array) -> mx.array:
         words = self.word_embeddings(input_ids)
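The hunk cuts off after the first line of `__call__`; for context, a minimal sketch of the standard BERT embedding forward pass using the renamed layers above (the position-index handling is illustrative, assuming `mx.arange` output broadcasts over the batch):

    def __call__(self, input_ids: mx.array, token_type_ids: mx.array) -> mx.array:
        words = self.word_embeddings(input_ids)
        # positions has shape (seq_len, dim) and broadcasts over the batch
        positions = self.position_embeddings(mx.arange(input_ids.shape[1]))
        token_types = self.token_type_embeddings(token_type_ids)
        # BERT sums all three embeddings, then applies LayerNorm
        return self.norm(words + positions + token_types)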
@@ -109,10 +110,10 @@ class Bert(nn.Module):
         self.embeddings = BertEmbeddings(config)
         self.encoder = TransformerEncoder(
             num_layers=config.num_hidden_layers,
-            dims=config.intermediate_size,
+            dims=config.dim,
             num_heads=config.num_attention_heads,
         )
-        self.pooler = nn.Linear(config.intermediate_size, config.vocab_size)
+        self.pooler = nn.Linear(config.dim, config.dim)
 
     def __call__(
         self,
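The pooler change is a real bug fix: BERT's pooler projects the [CLS] hidden state from dim to dim (followed by tanh), rather than projecting to the vocabulary. A sketch of how the corrected layer is typically applied (variable names here are hypothetical):

    y = self.encoder(x, mask)                 # (batch, seq_len, dim)
    pooled = mx.tanh(self.pooler(y[:, 0]))    # [CLS] token -> (batch, dim)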
@@ -132,39 +133,25 @@ class Bert(nn.Module):
 
 
 def load_model(bert_model: str, weights_path: str) -> tuple[Bert, BertTokenizer]:
-    # load the weights npz
-    weights = mx.load(weights_path)
-    weights = tree_unflatten(list(weights.items()))
+    if not Path(weights_path).exists():
+        raise ValueError(f"No model weights found in {weights_path}")
 
     # create and update the model
     model = Bert(model_configs[bert_model])
-    model.update(weights)
+    model.load_weights(weights_path)
 
     tokenizer = BertTokenizer.from_pretrained(bert_model)
 
     return model, tokenizer
 
 
-def run(bert_model: str, mlx_model: str):
+def run(bert_model: str, mlx_model: str, batch: List[str]):
     model, tokenizer = load_model(bert_model, mlx_model)
 
-    batch = [
-        "This is an example of BERT working on MLX.",
-        "A second string",
-        "This is another string.",
-    ]
-
     tokens = tokenizer(batch, return_tensors="np", padding=True)
     tokens = {key: mx.array(v) for key, v in tokens.items()}
-
-    mlx_output, mlx_pooled = model(**tokens)
-    mlx_output = numpy.array(mlx_output)
-    mlx_pooled = numpy.array(mlx_pooled)
-
-    print("MLX BERT:")
-    print(mlx_output)
-
-    print("\n\nMLX Pooled:")
-    print(mlx_pooled[0, :20])
+    return model(**tokens)
 
 
 if __name__ == "__main__":
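The loading change swaps the manual npz round-trip for the built-in `nn.Module.load_weights`. Restated as a side-by-side sketch, grounded in the deleted lines above (`tree_unflatten` lives in `mlx.utils`):

    from mlx.utils import tree_unflatten

    # before: load a flat {"layer.sub.weight": array} dict and graft it on manually
    weights = mx.load(weights_path)
    model.update(tree_unflatten(list(weights.items())))

    # after: the built-in helper does the same load-and-update in one call
    model.load_weights(weights_path)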
@@ -181,6 +168,11 @@ if __name__ == "__main__":
         default="weights/bert-base-uncased.npz",
         help="The path of the stored MLX BERT weights (npz file).",
     )
+    parser.add_argument(
+        "--text",
+        type=str,
+        default="This is an example of BERT working in MLX",
+        help="The text to generate embeddings for.",
+    )
     args = parser.parse_args()
 
-    run(args.bert_model, args.mlx_model)
+    run(args.bert_model, args.mlx_model, args.text)
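With the printing removed from `run`, callers now get the raw arrays back. A hedged usage sketch (assumes the weights were already converted to the default npz path; `run` is annotated for a `List[str]`, and the CLI passes the single `--text` string straight through, which the Hugging Face tokenizer also accepts):

    output, pooled = run(
        "bert-base-uncased",
        "weights/bert-base-uncased.npz",
        ["This is an example of BERT working in MLX"],
    )
    print(output.shape)  # (1, seq_len, 768) for bert-base
    print(pooled.shape)  # (1, 768)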