2023-12-21 02:22:25 +08:00
|
|
|
import argparse
|
2024-01-10 00:44:51 +08:00
|
|
|
from pathlib import Path
|
2024-01-13 03:15:09 +08:00
|
|
|
from typing import List, Optional, Tuple
|
2023-12-08 18:14:11 +08:00
|
|
|
|
|
|
|
import mlx.core as mx
|
|
|
|
import mlx.nn as nn
|
2023-12-21 02:22:25 +08:00
|
|
|
from mlx.utils import tree_unflatten
|
2024-03-20 08:21:33 +08:00
|
|
|
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase
|
2023-12-08 18:14:11 +08:00
|
|
|
|
|
|
|
|
2023-12-09 23:41:15 +08:00
|
|
|
class TransformerEncoderLayer(nn.Module):
    """
    One encoder block laid out like the original BERT (post-normalization).

    Each sub-layer's output is added back to its input and the sum is then
    layer-normalized, i.e. ``LayerNorm(h + Sublayer(h))``. This happens first
    for multi-head self-attention and then for a two-layer GELU feed-forward
    network.
    """

    def __init__(
        self,
        dims: int,
        num_heads: int,
        mlp_dims: Optional[int] = None,
        layer_norm_eps: float = 1e-12,
    ):
        """
        Args:
            dims: model (hidden) width.
            num_heads: number of attention heads.
            mlp_dims: inner feed-forward width; ``4 * dims`` when falsy
                (None or 0).
            layer_norm_eps: epsilon for both LayerNorms.
        """
        super().__init__()
        hidden_dims = dims * 4 if not mlp_dims else mlp_dims

        # Attribute names (and their order) are kept as-is: they are the keys
        # the stored weights are loaded against.
        self.attention = nn.MultiHeadAttention(dims, num_heads, bias=True)
        self.ln1 = nn.LayerNorm(dims, eps=layer_norm_eps)
        self.ln2 = nn.LayerNorm(dims, eps=layer_norm_eps)
        self.linear1 = nn.Linear(dims, hidden_dims)
        self.linear2 = nn.Linear(hidden_dims, dims)
        self.gelu = nn.GELU()

    def __call__(self, x, mask):
        """Apply attention then feed-forward, each followed by add & norm."""
        # Self-attention sub-layer (queries, keys and values are all x).
        attended = self.ln1(x + self.attention(x, x, x, mask))
        # Position-wise feed-forward sub-layer on the normalized output.
        projected = self.linear2(self.gelu(self.linear1(attended)))
        return self.ln2(projected + attended)
|
|
|
|
|
|
|
|
|
2023-12-09 23:41:15 +08:00
|
|
|
class TransformerEncoder(nn.Module):
    """
    A stack of ``num_layers`` identical ``TransformerEncoderLayer`` blocks,
    applied one after another.
    """

    def __init__(
        self,
        num_layers: int,
        dims: int,
        num_heads: int,
        mlp_dims: Optional[int] = None,
        layer_norm_eps: float = 1e-12,
    ):
        """
        Args:
            num_layers: number of encoder layers to stack.
            dims: model (hidden) width of every layer.
            num_heads: attention heads per layer.
            mlp_dims: feed-forward width; each layer falls back to
                ``4 * dims`` when this is falsy.
            layer_norm_eps: epsilon for the LayerNorms in every layer.
                Defaults to 1e-12, the value each layer used before this
                parameter existed, so existing callers behave the same.
        """
        super().__init__()
        self.layers = [
            TransformerEncoderLayer(
                dims, num_heads, mlp_dims, layer_norm_eps=layer_norm_eps
            )
            for _ in range(num_layers)
        ]

    def __call__(self, x, mask):
        """Run ``x`` through every layer in order, passing ``mask`` to each."""
        for layer in self.layers:
            x = layer(x, mask)
        return x
|
|
|
|
|
|
|
|
|
|
|
|
class BertEmbeddings(nn.Module):
    """
    BERT input embeddings: the sum of word, position and token-type
    embeddings, followed by LayerNorm.
    """

    def __init__(self, config):
        """
        Args:
            config: a Hugging Face model config. Reads ``vocab_size``,
                ``hidden_size``, ``type_vocab_size``,
                ``max_position_embeddings`` and ``layer_norm_eps``.
        """
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(
            config.type_vocab_size, config.hidden_size
        )
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size
        )
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def __call__(
        self, input_ids: mx.array, token_type_ids: Optional[mx.array] = None
    ) -> mx.array:
        """
        Embed a batch of token ids.

        Args:
            input_ids: integer token ids indexed as ``(batch, seq_len)``.
            token_type_ids: segment ids with the same shape as ``input_ids``.
                All zeros (a single segment) when omitted.

        Returns:
            LayerNorm of the summed word, position and token-type embeddings.

        Raises:
            ValueError: if ``seq_len`` exceeds the number of rows in the
                position-embedding table.
        """
        seq_len = input_ids.shape[1]
        max_positions = self.position_embeddings.weight.shape[0]
        if seq_len > max_positions:
            # Positions 0..seq_len-1 are looked up below; refuse up front
            # rather than indexing past the end of the position table.
            raise ValueError(
                f"Sequence length {seq_len} exceeds the maximum of "
                f"{max_positions} position embeddings."
            )

        words = self.word_embeddings(input_ids)
        # Every sequence in the batch uses positions 0..seq_len-1.
        position = self.position_embeddings(
            mx.broadcast_to(mx.arange(seq_len), input_ids.shape)
        )

        if token_type_ids is None:
            # If token_type_ids is not provided, default to zeros
            token_type_ids = mx.zeros_like(input_ids)
        token_types = self.token_type_embeddings(token_type_ids)

        embeddings = position + words + token_types
        return self.norm(embeddings)
|
|
|
|
|
|
|
|
|
|
|
|
class Bert(nn.Module):
    """
    BERT model: input embeddings, a post-norm transformer encoder stack, and
    a tanh-activated linear pooler applied to the first token's output.
    """

    def __init__(self, config):
        """
        Args:
            config: a Hugging Face model config. Besides the fields read by
                ``BertEmbeddings``, reads ``num_hidden_layers``,
                ``num_attention_heads`` and ``intermediate_size``.
        """
        super().__init__()
        self.embeddings = BertEmbeddings(config)
        self.encoder = TransformerEncoder(
            num_layers=config.num_hidden_layers,
            dims=config.hidden_size,
            num_heads=config.num_attention_heads,
            mlp_dims=config.intermediate_size,
        )
        self.pooler = nn.Linear(config.hidden_size, config.hidden_size)

    def __call__(
        self,
        input_ids: mx.array,
        token_type_ids: Optional[mx.array] = None,
        attention_mask: Optional[mx.array] = None,
    ) -> Tuple[mx.array, mx.array]:
        """
        Encode a batch of token ids.

        Args:
            input_ids: token ids indexed as ``(batch, seq_len)``.
            token_type_ids: optional segment ids, same shape as ``input_ids``.
            attention_mask: optional ``(batch, seq_len)`` mask with 1 for
                positions to attend to and 0 for positions to ignore.

        Returns:
            ``(hidden_states, pooled)``: ``hidden_states`` is the encoder
            output for every token. ``pooled`` is
            ``tanh(pooler(hidden_states[:, 0]))``, computed from the first
            token only.
        """
        x = self.embeddings(input_ids, token_type_ids)

        if attention_mask is not None:
            # convert 0's to -infs, 1's to 0's, and make it broadcastable.
            # Cast to the activations' dtype first so the additive mask
            # matches the model's precision regardless of the mask's input
            # dtype (e.g. the integer arrays a tokenizer produces).
            attention_mask = mx.log(attention_mask.astype(x.dtype))
            # (batch, seq_len) -> (batch, 1, 1, seq_len): masks keys for every
            # head and every query position.
            attention_mask = mx.expand_dims(attention_mask, (1, 2))

        y = self.encoder(x, attention_mask)
        return y, mx.tanh(self.pooler(y[:, 0]))
|
|
|
|
|
|
|
|
|
2024-03-20 08:21:33 +08:00
|
|
|
def load_model(
    bert_model: str, weights_path: str
) -> Tuple[Bert, PreTrainedTokenizerBase]:
    """
    Build an MLX ``Bert`` with stored weights, plus its tokenizer.

    The model config and the tokenizer are obtained through Hugging Face
    ``AutoConfig`` / ``AutoTokenizer`` from ``bert_model``. The parameters
    are then loaded from the file at ``weights_path``.

    Raises:
        ValueError: if nothing exists at ``weights_path``.
    """
    weights_file = Path(weights_path)
    if not weights_file.exists():
        raise ValueError(f"No model weights found in {weights_path}")

    bert = Bert(AutoConfig.from_pretrained(bert_model))
    bert.load_weights(weights_path)
    return bert, AutoTokenizer.from_pretrained(bert_model)
|
|
|
|
|
|
|
|
|
2024-01-10 00:44:51 +08:00
|
|
|
def run(bert_model: str, mlx_model: str, batch: List[str]):
    """
    Tokenize ``batch`` and run it through the MLX BERT model.

    Args:
        bert_model: Hugging Face model name used for the config and tokenizer.
        mlx_model: path to the converted MLX weights file.
        batch: texts to encode. The tokenizer is called with ``padding=True``
            so the batch can be stacked into arrays.

    Returns:
        The result of calling the model on the tokenized batch.
    """
    model, tokenizer = load_model(bert_model, mlx_model)

    encoded = tokenizer(batch, return_tensors="np", padding=True)
    model_inputs = {}
    for name, values in encoded.items():
        # Tokenizer output names double as the model's keyword arguments.
        model_inputs[name] = mx.array(values)

    return model(**model_inputs)
|
2023-12-08 18:14:11 +08:00
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: embed a single piece of text with MLX BERT.
    parser = argparse.ArgumentParser(description="Run the BERT model using MLX.")
    parser.add_argument(
        "--bert-model",
        type=str,
        default="bert-base-uncased",
        help="The huggingface name of the BERT model to use.",
    )
    parser.add_argument(
        "--mlx-model",
        type=str,
        default="weights/bert-base-uncased.npz",
        help="The path of the stored MLX BERT weights (npz file).",
    )
    parser.add_argument(
        "--text",
        type=str,
        default="This is an example of BERT working in MLX",
        help="The text to generate embeddings for.",
    )
    args = parser.parse_args()
    # `run` takes a batch (list) of strings, so wrap the single CLI text.
    run(args.bert_model, args.mlx_model, [args.text])
|