Enable more BERT models (#580)

* Update convert.py

* Update model.py

* Update test.py

* Update model.py

* Update convert.py

* Add files via upload

* Update convert.py

* format

* nit

* nit

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
yzimmermann
2024-03-20 01:21:33 +01:00
committed by GitHub
parent b0bcd86a40
commit 4680ef4413
4 changed files with 76 additions and 68 deletions

View File

@@ -8,31 +8,7 @@ import mlx.nn as nn
import numpy
import numpy as np
from mlx.utils import tree_unflatten
from transformers import BertTokenizer
@dataclass
class ModelArgs:
dim: int = 768
num_attention_heads: int = 12
num_hidden_layers: int = 12
vocab_size: int = 30522
attention_probs_dropout_prob: float = 0.1
hidden_dropout_prob: float = 0.1
layer_norm_eps: float = 1e-12
max_position_embeddings: int = 512
model_configs = {
"bert-base-uncased": ModelArgs(),
"bert-base-cased": ModelArgs(),
"bert-large-uncased": ModelArgs(
dim=1024, num_attention_heads=16, num_hidden_layers=24
),
"bert-large-cased": ModelArgs(
dim=1024, num_attention_heads=16, num_hidden_layers=24
),
}
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase
class TransformerEncoderLayer(nn.Module):
@@ -86,20 +62,29 @@ class TransformerEncoder(nn.Module):
class BertEmbeddings(nn.Module):
def __init__(self, config: ModelArgs):
def __init__(self, config):
super().__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.dim)
self.token_type_embeddings = nn.Embedding(2, config.dim)
self.position_embeddings = nn.Embedding(
config.max_position_embeddings, config.dim
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
self.token_type_embeddings = nn.Embedding(
config.type_vocab_size, config.hidden_size
)
self.norm = nn.LayerNorm(config.dim, eps=config.layer_norm_eps)
self.position_embeddings = nn.Embedding(
config.max_position_embeddings, config.hidden_size
)
self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def __call__(self, input_ids: mx.array, token_type_ids: mx.array) -> mx.array:
def __call__(
self, input_ids: mx.array, token_type_ids: mx.array = None
) -> mx.array:
words = self.word_embeddings(input_ids)
position = self.position_embeddings(
mx.broadcast_to(mx.arange(input_ids.shape[1]), input_ids.shape)
)
if token_type_ids is None:
# If token_type_ids is not provided, default to zeros
token_type_ids = mx.zeros_like(input_ids)
token_types = self.token_type_embeddings(token_type_ids)
embeddings = position + words + token_types
@@ -107,20 +92,21 @@ class BertEmbeddings(nn.Module):
class Bert(nn.Module):
def __init__(self, config: ModelArgs):
def __init__(self, config):
super().__init__()
self.embeddings = BertEmbeddings(config)
self.encoder = TransformerEncoder(
num_layers=config.num_hidden_layers,
dims=config.dim,
dims=config.hidden_size,
num_heads=config.num_attention_heads,
mlp_dims=config.intermediate_size,
)
self.pooler = nn.Linear(config.dim, config.dim)
self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
def __call__(
self,
input_ids: mx.array,
token_type_ids: mx.array,
token_type_ids: mx.array = None,
attention_mask: mx.array = None,
) -> Tuple[mx.array, mx.array]:
x = self.embeddings(input_ids, token_type_ids)
@@ -134,15 +120,19 @@ class Bert(nn.Module):
return y, mx.tanh(self.pooler(y[:, 0]))
def load_model(bert_model: str, weights_path: str) -> Tuple[Bert, BertTokenizer]:
def load_model(
bert_model: str, weights_path: str
) -> Tuple[Bert, PreTrainedTokenizerBase]:
if not Path(weights_path).exists():
raise ValueError(f"No model weights found in {weights_path}")
config = AutoConfig.from_pretrained(bert_model)
# create and update the model
model = Bert(model_configs[bert_model])
model = Bert(config)
model.load_weights(weights_path)
tokenizer = BertTokenizer.from_pretrained(bert_model)
tokenizer = AutoTokenizer.from_pretrained(bert_model)
return model, tokenizer