Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-24 09:21:18 +08:00
Enable more BERT models (#580)

* Update convert.py
* Update model.py
* Update test.py
* Update model.py
* Update convert.py
* Add files via upload
* Update convert.py
* format
* nit
* nit

Co-authored-by: Awni Hannun <awni@apple.com>

Commit 4680ef4413 (parent b0bcd86a40)
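Taken together, the changes below mean the convert/load path is no longer restricted to the four hardcoded BERT checkpoints. The following is a rough end-to-end sketch of the workflow this enables; the multilingual checkpoint name is only an illustrative assumption, not something exercised by this commit.

```python
# Hedged sketch of the workflow this commit enables, run from the bert/ example
# directory. Assumes the weights/ output directory already exists; the
# multilingual checkpoint is an arbitrary illustrative choice.
from convert import convert
from model import load_model

hf_name = "bert-base-multilingual-uncased"
npz_path = "weights/bert-base-multilingual-uncased.npz"

convert(hf_name, npz_path)  # download from Hugging Face and save MLX-compatible weights
model, tokenizer = load_model(hf_name, npz_path)  # rebuild the MLX Bert from AutoConfig
```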
bert/README.md
@@ -1,6 +1,6 @@
 # BERT
 
-An implementation of BERT [(Devlin, et al., 2019)](https://aclanthology.org/N19-1423/) within MLX.
+An implementation of BERT [(Devlin, et al., 2019)](https://aclanthology.org/N19-1423/) in MLX.
 
 ## Setup
 
@@ -38,12 +38,12 @@ output, pooled = model(**tokens)
 ```
 
 The `output` contains a `Batch x Tokens x Dims` tensor, representing a vector
-for every input token. If you want to train anything at a **token-level**,
-you'll want to use this.
+for every input token. If you want to train anything at the **token-level**,
+use this.
 
 The `pooled` contains a `Batch x Dims` tensor, which is the pooled
 representation for each input. If you want to train a **classification**
-model, you'll want to use this.
+model, use this.
 
 
 ## Test
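The hunk above documents the two return values of the model. A minimal usage sketch consistent with that description (assuming weights have already been converted with convert.py) might look like this:

```python
# Minimal sketch of consuming the two outputs described in the README hunk above.
# Assumes weights/bert-base-uncased.npz was produced by convert.py beforehand.
import mlx.core as mx
from model import load_model

model, tokenizer = load_model("bert-base-uncased", "weights/bert-base-uncased.npz")

tokens = tokenizer(
    ["This is an example of BERT working in MLX."],
    return_tensors="np",
    padding=True,
)
tokens = {key: mx.array(v) for key, v in tokens.items()}

output, pooled = model(**tokens)
# output: Batch x Tokens x Dims -> one vector per input token (token-level tasks)
# pooled: Batch x Dims          -> one vector per sequence (classification tasks)
```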
bert/convert.py
@@ -1,7 +1,7 @@
 import argparse
 
 import numpy
-from transformers import BertModel
+from transformers import AutoModel
 
 
 def replace_key(key: str) -> str:
@@ -20,7 +20,7 @@ def replace_key(key: str) -> str:
 
 
 def convert(bert_model: str, mlx_model: str) -> None:
-    model = BertModel.from_pretrained(bert_model)
+    model = AutoModel.from_pretrained(bert_model)
     # save the tensors
     tensors = {
         replace_key(key): tensor.numpy() for key, tensor in model.state_dict().items()
@@ -32,14 +32,9 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Convert BERT weights to MLX.")
     parser.add_argument(
         "--bert-model",
-        choices=[
-            "bert-base-uncased",
-            "bert-base-cased",
-            "bert-large-uncased",
-            "bert-large-cased",
-        ],
+        type=str,
         default="bert-base-uncased",
-        help="The huggingface name of the BERT model to save.",
+        help="The huggingface name of the BERT model to save. Any BERT-like model can be specified.",
     )
     parser.add_argument(
         "--mlx-model",
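For context, the conversion body touched above amounts to renaming the Hugging Face state-dict keys and dumping the arrays. A simplified sketch follows; `replace_key` here is a stand-in, and writing the arrays with `numpy.savez` is an assumption based on the `.npz` weight paths used elsewhere in this commit.

```python
# Simplified sketch of the conversion step shown in the hunk above.
# replace_key is a stand-in; the real renaming rules live in bert/convert.py,
# and numpy.savez is an assumption inferred from the .npz paths in test.py.
import numpy
from transformers import AutoModel


def replace_key(key: str) -> str:
    # Stand-in: the real function maps Hugging Face parameter names to the
    # names expected by the MLX Bert module.
    return key


model = AutoModel.from_pretrained("bert-base-uncased")
tensors = {
    replace_key(key): tensor.numpy() for key, tensor in model.state_dict().items()
}
numpy.savez("weights/bert-base-uncased.npz", **tensors)
```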
bert/model.py
@@ -8,31 +8,7 @@ import mlx.nn as nn
 import numpy
 import numpy as np
 from mlx.utils import tree_unflatten
-from transformers import BertTokenizer
+from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase
 
 
-@dataclass
-class ModelArgs:
-    dim: int = 768
-    num_attention_heads: int = 12
-    num_hidden_layers: int = 12
-    vocab_size: int = 30522
-    attention_probs_dropout_prob: float = 0.1
-    hidden_dropout_prob: float = 0.1
-    layer_norm_eps: float = 1e-12
-    max_position_embeddings: int = 512
-
-
-model_configs = {
-    "bert-base-uncased": ModelArgs(),
-    "bert-base-cased": ModelArgs(),
-    "bert-large-uncased": ModelArgs(
-        dim=1024, num_attention_heads=16, num_hidden_layers=24
-    ),
-    "bert-large-cased": ModelArgs(
-        dim=1024, num_attention_heads=16, num_hidden_layers=24
-    ),
-}
-
-
 class TransformerEncoderLayer(nn.Module):
@@ -86,20 +62,29 @@ class TransformerEncoder(nn.Module):
 
 
 class BertEmbeddings(nn.Module):
-    def __init__(self, config: ModelArgs):
+    def __init__(self, config):
         super().__init__()
-        self.word_embeddings = nn.Embedding(config.vocab_size, config.dim)
-        self.token_type_embeddings = nn.Embedding(2, config.dim)
-        self.position_embeddings = nn.Embedding(
-            config.max_position_embeddings, config.dim
+        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
+        self.token_type_embeddings = nn.Embedding(
+            config.type_vocab_size, config.hidden_size
         )
-        self.norm = nn.LayerNorm(config.dim, eps=config.layer_norm_eps)
+        self.position_embeddings = nn.Embedding(
+            config.max_position_embeddings, config.hidden_size
+        )
+        self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
-    def __call__(self, input_ids: mx.array, token_type_ids: mx.array) -> mx.array:
+    def __call__(
+        self, input_ids: mx.array, token_type_ids: mx.array = None
+    ) -> mx.array:
         words = self.word_embeddings(input_ids)
         position = self.position_embeddings(
             mx.broadcast_to(mx.arange(input_ids.shape[1]), input_ids.shape)
         )
 
+        if token_type_ids is None:
+            # If token_type_ids is not provided, default to zeros
+            token_type_ids = mx.zeros_like(input_ids)
+
         token_types = self.token_type_embeddings(token_type_ids)
 
         embeddings = position + words + token_types
@@ -107,20 +92,21 @@ class BertEmbeddings(nn.Module):
 
 
 class Bert(nn.Module):
-    def __init__(self, config: ModelArgs):
+    def __init__(self, config):
         super().__init__()
         self.embeddings = BertEmbeddings(config)
         self.encoder = TransformerEncoder(
             num_layers=config.num_hidden_layers,
-            dims=config.dim,
+            dims=config.hidden_size,
             num_heads=config.num_attention_heads,
+            mlp_dims=config.intermediate_size,
         )
-        self.pooler = nn.Linear(config.dim, config.dim)
+        self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
 
     def __call__(
         self,
         input_ids: mx.array,
-        token_type_ids: mx.array,
+        token_type_ids: mx.array = None,
         attention_mask: mx.array = None,
     ) -> Tuple[mx.array, mx.array]:
         x = self.embeddings(input_ids, token_type_ids)
@@ -134,15 +120,19 @@ class Bert(nn.Module):
         return y, mx.tanh(self.pooler(y[:, 0]))
 
 
-def load_model(bert_model: str, weights_path: str) -> Tuple[Bert, BertTokenizer]:
+def load_model(
+    bert_model: str, weights_path: str
+) -> Tuple[Bert, PreTrainedTokenizerBase]:
     if not Path(weights_path).exists():
         raise ValueError(f"No model weights found in {weights_path}")
 
+    config = AutoConfig.from_pretrained(bert_model)
+
     # create and update the model
-    model = Bert(model_configs[bert_model])
+    model = Bert(config)
     model.load_weights(weights_path)
 
-    tokenizer = BertTokenizer.from_pretrained(bert_model)
+    tokenizer = AutoTokenizer.from_pretrained(bert_model)
 
     return model, tokenizer
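One practical effect of the token_type_ids changes above is that the model can now be called without segment ids. A hedged sketch (assuming converted weights are already available):

```python
# Hedged sketch: token_type_ids can now be omitted, in which case BertEmbeddings
# falls back to zeros (the single-segment case). Assumes converted weights exist.
import mlx.core as mx
from model import load_model

model, tokenizer = load_model("bert-base-uncased", "weights/bert-base-uncased.npz")
encoded = tokenizer(["Segment ids are optional now."], return_tensors="np", padding=True)

output, pooled = model(
    input_ids=mx.array(encoded["input_ids"]),
    attention_mask=mx.array(encoded["attention_mask"]),
)
```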
bert/test.py
@@ -1,3 +1,4 @@
+import argparse
 from typing import List
 
 import model
@@ -16,19 +17,41 @@ def run_torch(bert_model: str, batch: List[str]):
 
 
 if __name__ == "__main__":
-    bert_model = "bert-base-uncased"
-    mlx_model = "weights/bert-base-uncased.npz"
-    batch = [
-        "This is an example of BERT working in MLX.",
-        "A second string",
-        "This is another string.",
-    ]
-    torch_output, torch_pooled = run_torch(bert_model, batch)
-    mlx_output, mlx_pooled = model.run(bert_model, mlx_model, batch)
-    assert np.allclose(
-        torch_output, mlx_output, rtol=1e-4, atol=1e-5
-    ), "Model output is different"
-    assert np.allclose(
-        torch_pooled, mlx_pooled, rtol=1e-4, atol=1e-5
-    ), "Model pooled output is different"
-    print("Tests pass :)")
+    parser = argparse.ArgumentParser(
+        description="Run a BERT-like model for a batch of text."
+    )
+    parser.add_argument(
+        "--bert-model",
+        type=str,
+        default="bert-base-uncased",
+        help="The model identifier for a BERT-like model from Hugging Face Transformers.",
+    )
+    parser.add_argument(
+        "--mlx-model",
+        type=str,
+        default="weights/bert-base-uncased.npz",
+        help="The path of the stored MLX BERT weights (npz file).",
+    )
+    parser.add_argument(
+        "--text",
+        nargs="+",
+        default=["This is an example of BERT working in MLX."],
+        help="A batch of texts to process. Multiple texts should be separated by spaces.",
+    )
+
+    args = parser.parse_args()
+
+    torch_output, torch_pooled = run_torch(args.bert_model, args.text)
+
+    mlx_output, mlx_pooled = model.run(args.bert_model, args.mlx_model, args.text)
+
+    if torch_pooled is not None and mlx_pooled is not None:
+        assert np.allclose(
+            torch_output, mlx_output, rtol=1e-4, atol=1e-5
+        ), "Model output is different"
+        assert np.allclose(
+            torch_pooled, mlx_pooled, rtol=1e-4, atol=1e-5
+        ), "Model pooled output is different"
+        print("Tests pass :)")
+    else:
+        print("Pooled outputs were not compared due to one or both being None.")
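The same comparison can also be driven directly from Python rather than through the new CLI flags. A sketch, run from the bert/ directory, assuming the matching weights file exists:

```python
# Hedged sketch: calling the helpers used by test.py directly. run_torch takes
# (bert_model, batch) and model.run takes (bert_model, mlx_model, batch), as in
# the hunk above; weights/bert-base-uncased.npz must already exist.
import numpy as np

import model
from test import run_torch  # bert/test.py; shadows the stdlib test package when run locally

texts = ["This is an example of BERT working in MLX.", "A second string"]
torch_output, torch_pooled = run_torch("bert-base-uncased", texts)
mlx_output, mlx_pooled = model.run("bert-base-uncased", "weights/bert-base-uncased.npz", texts)

assert np.allclose(torch_output, mlx_output, rtol=1e-4, atol=1e-5)
assert np.allclose(torch_pooled, mlx_pooled, rtol=1e-4, atol=1e-5)
print("Outputs match :)")
```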