# mlx-examples/bert/model.py

from typing import Optional
from dataclasses import dataclass
from mlx.utils import tree_unflatten, tree_map
from mlx.nn.layers.base import Module
from mlx.nn.layers.linear import Linear
from mlx.nn.layers.normalization import LayerNorm
from transformers import AutoTokenizer
import mlx.core as mx
import mlx.nn as nn
import argparse
import numpy
import math


@dataclass
class ModelArgs:
intermediate_size: int = 768
num_attention_heads: int = 12
num_hidden_layers: int = 12
vocab_size: int = 30522
attention_probs_dropout_prob: float = 0.1
hidden_dropout_prob: float = 0.1
layer_norm_eps: float = 1e-12
    max_position_embeddings: int = 512


# Note: this implementation uses `intermediate_size` as the model's hidden width
# (768 for the base models, 1024 for the large ones); the feed-forward width is
# the default mlp_dims = 4 * dims set in TransformerEncoderLayer.
model_configs = {
"bert-base-uncased": ModelArgs(),
"bert-base-cased": ModelArgs(),
"bert-large-uncased": ModelArgs(
intermediate_size=1024, num_attention_heads=16, num_hidden_layers=24
),
"bert-large-cased": ModelArgs(
intermediate_size=1024, num_attention_heads=16, num_hidden_layers=24
),
}


class MultiHeadAttention(Module):
    """
    A multi-head scaled dot-product attention module in which the query, key,
    value, and output projections all use a bias, as in the original BERT.
    """
def __init__(
self,
dims: int,
num_heads: int,
query_input_dims: Optional[int] = None,
key_input_dims: Optional[int] = None,
value_input_dims: Optional[int] = None,
value_dims: Optional[int] = None,
value_output_dims: Optional[int] = None,
):
super().__init__()
if (dims % num_heads) != 0:
raise ValueError(
f"The input feature dimensions should be divisible by the number of heads ({dims} % {num_heads}) != 0"
)
query_input_dims = query_input_dims or dims
key_input_dims = key_input_dims or dims
value_input_dims = value_input_dims or key_input_dims
value_dims = value_dims or dims
value_output_dims = value_output_dims or dims
self.num_heads = num_heads
self.query_proj = Linear(query_input_dims, dims, True)
self.key_proj = Linear(key_input_dims, dims, True)
self.value_proj = Linear(value_input_dims, value_dims, True)
self.out_proj = Linear(value_dims, value_output_dims, True)
def __call__(self, queries, keys, values, mask=None):
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
num_heads = self.num_heads
B, L, D = queries.shape
_, S, _ = keys.shape
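        # Split each projection into heads: queries and values become
        # [B, num_heads, seq, head_dim]; keys are transposed to
        # [B, num_heads, head_dim, seq] so that `queries @ keys` yields scores.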
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
# Dimensions are [batch x num heads x sequence x hidden dim]
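        # Scaled dot-product attention: scale by 1 / sqrt(head_dim) before the matmul.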
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ keys
if mask is not None:
            mask = self.convert_mask_to_additive_causal_mask(mask)
mask = mx.expand_dims(mask, (1, 2))
mask = mx.broadcast_to(mask, scores.shape)
scores = scores + mask.astype(scores.dtype)
scores = mx.softmax(scores, axis=-1)
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.out_proj(values_hat)
    def convert_mask_to_additive_causal_mask(
        self, mask: mx.array, dtype: mx.Dtype = mx.float32
    ) -> mx.array:
        # Turn a 0/1 padding mask into an additive mask: padded positions get a
        # large negative value so that softmax gives them (near) zero weight.
        mask = mask == 0
        mask = mask.astype(dtype) * -1e9
        return mask


class TransformerEncoderLayer(Module):
"""
A transformer encoder layer with (the original BERT) post-normalization.
"""
def __init__(
self,
dims: int,
num_heads: int,
mlp_dims: Optional[int] = None,
layer_norm_eps: float = 1e-12,
):
super().__init__()
mlp_dims = mlp_dims or dims * 4
self.attention = MultiHeadAttention(dims, num_heads)
self.ln1 = LayerNorm(dims, eps=layer_norm_eps)
self.ln2 = LayerNorm(dims, eps=layer_norm_eps)
self.linear1 = Linear(dims, mlp_dims)
self.linear2 = Linear(mlp_dims, dims)
self.gelu = nn.GELU()
def __call__(self, x, mask):
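        # Post-norm (original BERT) ordering: self-attention, residual add, LayerNorm,
        # then the feed-forward block with its own residual add and LayerNorm.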
attention_out = self.attention(x, x, x, mask)
add_and_norm = self.ln1(x + attention_out)
ff = self.linear1(add_and_norm)
ff_gelu = self.gelu(ff)
ff_out = self.linear2(ff_gelu)
x = self.ln2(ff_out + add_and_norm)
        return x


class TransformerEncoder(Module):
def __init__(
self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
):
super().__init__()
self.layers = [
TransformerEncoderLayer(dims, num_heads, mlp_dims)
for i in range(num_layers)
]
    def __call__(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        return x


class BertEmbeddings(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.intermediate_size)
self.token_type_embeddings = nn.Embedding(2, config.intermediate_size)
self.position_embeddings = nn.Embedding(
config.max_position_embeddings, config.intermediate_size
)
self.norm = nn.LayerNorm(config.intermediate_size, eps=config.layer_norm_eps)
def __call__(self, input_ids: mx.array, token_type_ids: mx.array) -> mx.array:
words = self.word_embeddings(input_ids)
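        # Position ids are simply 0..L-1, broadcast across the batch dimension.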
position = self.position_embeddings(
mx.broadcast_to(mx.arange(input_ids.shape[1]), input_ids.shape)
)
token_types = self.token_type_embeddings(token_type_ids)
embeddings = position + words + token_types
        return self.norm(embeddings)


class Bert(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.embeddings = BertEmbeddings(config)
self.encoder = TransformerEncoder(
num_layers=config.num_hidden_layers,
dims=config.intermediate_size,
num_heads=config.num_attention_heads,
)
        # The BERT pooler is a dense layer + tanh applied to the [CLS] (first) token.
        self.pooler = nn.Linear(config.intermediate_size, config.intermediate_size)
def __call__(
self,
input_ids: mx.array,
token_type_ids: mx.array,
attention_mask: mx.array | None = None,
) -> tuple[mx.array, mx.array]:
x = self.embeddings(input_ids, token_type_ids)
y = self.encoder(x, attention_mask)
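        # Return the per-token hidden states and the pooled output:
        # tanh(pooler(hidden state of the first, [CLS], token)).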
        return y, mx.tanh(self.pooler(y[:, 0]))


def run(bert_model: str, mlx_model: str):
batch = [
"This is an example of BERT working on MLX.",
"A second string",
"This is another string.",
]
model = Bert(model_configs[bert_model])
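    # Load the converted .npz weights; the flat keys (e.g. "encoder.layers.0. ...")
    # are unflattened into a nested tree that matches the module hierarchy.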
weights = mx.load(mlx_model)
weights = tree_unflatten(list(weights.items()))
weights = tree_map(lambda p: mx.array(p), weights)
model.update(weights)
tokenizer = AutoTokenizer.from_pretrained(bert_model)
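    # Tokenize with padding and convert the returned numpy arrays to mx.array;
    # the keys (input_ids, token_type_ids, attention_mask) match Bert.__call__.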
tokens = tokenizer(batch, return_tensors="np", padding=True)
tokens = {key: mx.array(v) for key, v in tokens.items()}
mlx_output, mlx_pooled = model(**tokens)
mlx_output = numpy.array(mlx_output)
mlx_pooled = numpy.array(mlx_pooled)
print("MLX BERT:")
print(mlx_output)
print("\n\nMLX Pooled:")
    print(mlx_pooled[0, :20])


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert BERT weights to MLX.")
parser.add_argument(
"--bert-model",
type=str,
default="bert-base-uncased",
help="The huggingface name of the BERT model to save.",
)
parser.add_argument(
"--mlx-model",
type=str,
default="weights/bert-base-uncased.npz",
help="The output path for the MLX BERT weights.",
)
args = parser.parse_args()
run(args.bert_model, args.mlx_model)
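
# Example usage (assuming the weights have already been converted to an .npz
# file, e.g. with the convert script that accompanies this example):
#
#   python model.py --bert-model bert-base-uncased \
#       --mlx-model weights/bert-base-uncased.npz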