Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-09-01 12:49:50 +08:00)
Enable more BERT models (#580)
* Update convert.py
* Update model.py
* Update test.py
* Update model.py
* Update convert.py
* Add files via upload
* Update convert.py
* format
* nit
* nit

---------

Co-authored-by: Awni Hannun <awni@apple.com>
@@ -8,31 +8,7 @@ import mlx.nn as nn
-import numpy
 import numpy as np
-from mlx.utils import tree_unflatten
-from transformers import BertTokenizer
-
-
-@dataclass
-class ModelArgs:
-    dim: int = 768
-    num_attention_heads: int = 12
-    num_hidden_layers: int = 12
-    vocab_size: int = 30522
-    attention_probs_dropout_prob: float = 0.1
-    hidden_dropout_prob: float = 0.1
-    layer_norm_eps: float = 1e-12
-    max_position_embeddings: int = 512
-
-
-model_configs = {
-    "bert-base-uncased": ModelArgs(),
-    "bert-base-cased": ModelArgs(),
-    "bert-large-uncased": ModelArgs(
-        dim=1024, num_attention_heads=16, num_hidden_layers=24
-    ),
-    "bert-large-cased": ModelArgs(
-        dim=1024, num_attention_heads=16, num_hidden_layers=24
-    ),
-}
+from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase


 class TransformerEncoderLayer(nn.Module):
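For context, the four hard-coded presets removed above only covered the base and large uncased/cased checkpoints. After this change the same hyperparameters are read from the checkpoint's Hugging Face config, so other BERT-family models can be used. A minimal sketch of where those values now come from (the model id is just an example; the commented values are for bert-large-uncased):

    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("bert-large-uncased")
    print(config.hidden_size)          # 1024, replaces ModelArgs.dim
    print(config.num_attention_heads)  # 16
    print(config.num_hidden_layers)    # 24
    print(config.intermediate_size)    # 4096, feed-forward width passed as mlp_dims
    print(config.type_vocab_size)      # 2, replaces the hard-coded token type count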
@@ -86,20 +62,29 @@ class TransformerEncoder(nn.Module):


 class BertEmbeddings(nn.Module):
-    def __init__(self, config: ModelArgs):
+    def __init__(self, config):
         super().__init__()
-        self.word_embeddings = nn.Embedding(config.vocab_size, config.dim)
-        self.token_type_embeddings = nn.Embedding(2, config.dim)
-        self.position_embeddings = nn.Embedding(
-            config.max_position_embeddings, config.dim
-        )
-        self.norm = nn.LayerNorm(config.dim, eps=config.layer_norm_eps)
+        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
+        self.token_type_embeddings = nn.Embedding(
+            config.type_vocab_size, config.hidden_size
+        )
+        self.position_embeddings = nn.Embedding(
+            config.max_position_embeddings, config.hidden_size
+        )
+        self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

-    def __call__(self, input_ids: mx.array, token_type_ids: mx.array) -> mx.array:
+    def __call__(
+        self, input_ids: mx.array, token_type_ids: mx.array = None
+    ) -> mx.array:
         words = self.word_embeddings(input_ids)
         position = self.position_embeddings(
             mx.broadcast_to(mx.arange(input_ids.shape[1]), input_ids.shape)
         )
+
+        if token_type_ids is None:
+            # If token_type_ids is not provided, default to zeros
+            token_type_ids = mx.zeros_like(input_ids)
+
         token_types = self.token_type_embeddings(token_type_ids)

         embeddings = position + words + token_types
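The embedding change above makes token_type_ids optional, defaulting it to all zeros (segment 0), which is what single-sentence inputs use. A small sketch of the new behavior, assuming `embeddings` is an already constructed BertEmbeddings instance (the token ids are purely illustrative):

    import mlx.core as mx

    input_ids = mx.array([[101, 7592, 2088, 102]])  # toy batch: one sequence of 4 ids

    # Omitting token_type_ids is now equivalent to passing explicit zeros:
    out_default = embeddings(input_ids)
    out_explicit = embeddings(input_ids, mx.zeros_like(input_ids))
    assert mx.array_equal(out_default, out_explicit).item()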
@@ -107,20 +92,21 @@ class BertEmbeddings(nn.Module):


 class Bert(nn.Module):
-    def __init__(self, config: ModelArgs):
+    def __init__(self, config):
         super().__init__()
         self.embeddings = BertEmbeddings(config)
         self.encoder = TransformerEncoder(
             num_layers=config.num_hidden_layers,
-            dims=config.dim,
+            dims=config.hidden_size,
             num_heads=config.num_attention_heads,
+            mlp_dims=config.intermediate_size,
         )
-        self.pooler = nn.Linear(config.dim, config.dim)
+        self.pooler = nn.Linear(config.hidden_size, config.hidden_size)

     def __call__(
         self,
         input_ids: mx.array,
-        token_type_ids: mx.array,
+        token_type_ids: mx.array = None,
         attention_mask: mx.array = None,
     ) -> Tuple[mx.array, mx.array]:
         x = self.embeddings(input_ids, token_type_ids)
@@ -134,15 +120,19 @@ class Bert(nn.Module):
         return y, mx.tanh(self.pooler(y[:, 0]))


-def load_model(bert_model: str, weights_path: str) -> Tuple[Bert, BertTokenizer]:
+def load_model(
+    bert_model: str, weights_path: str
+) -> Tuple[Bert, PreTrainedTokenizerBase]:
     if not Path(weights_path).exists():
         raise ValueError(f"No model weights found in {weights_path}")

+    config = AutoConfig.from_pretrained(bert_model)
+
     # create and update the model
-    model = Bert(model_configs[bert_model])
+    model = Bert(config)
     model.load_weights(weights_path)

-    tokenizer = BertTokenizer.from_pretrained(bert_model)
+    tokenizer = AutoTokenizer.from_pretrained(bert_model)

     return model, tokenizer
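A hedged end-to-end sketch of the updated loading path. The model id and weights path are illustrative, and the example assumes the checkpoint has already been converted to an .npz file with the accompanying convert.py:

    import mlx.core as mx
    from model import load_model  # model.py as shown in this diff

    model, tokenizer = load_model("bert-base-uncased", "weights/bert-base-uncased.npz")

    tokens = tokenizer(["This is an example of BERT working in MLX."],
                       return_tensors="np", padding=True)
    tokens = {key: mx.array(value) for key, value in tokens.items()}

    # Returns the per-token hidden states and the pooled [CLS] output.
    last_hidden, pooled = model(**tokens)
    print(last_hidden.shape, pooled.shape)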