BERT implementation

This commit is contained in:
Joe Barrow 2023-12-08 05:14:11 -05:00
parent ff3cc56c8d
commit 4e5b8ceafe
5 changed files with 399 additions and 0 deletions

68
bert/README.md Normal file

@@ -0,0 +1,68 @@
# mlxbert
A BERT implementation in Apple's new MLX framework.
## Dependency Installation
```sh
poetry install --no-root
```
If you don't want to use Poetry, simply make sure you have the following dependencies installed (see the pip sketch after this list):
- `mlx`
- `transformers`
- `numpy`
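For example, a minimal pip equivalent (these are the standard PyPI package names) is:
```sh
pip install mlx transformers numpy
```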
## Download and Convert
```sh
python convert.py \
    --bert-model bert-base-uncased \
    --mlx-model weights/bert-base-uncased.npz
```
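To sanity-check the conversion, the resulting archive can be inspected directly. A minimal sketch, assuming the default output path above:
```python
import numpy

weights = numpy.load("weights/bert-base-uncased.npz")
print(len(weights.files))         # number of converted tensors
print(sorted(weights.files)[:5])  # a few of the remapped parameter names
```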
## Run the Model
Right now, this is just a test to show that the outputs from MLX and Hugging Face match closely.
```sh
python model.py \
--bert-model bert-base-uncased \
--mlx-model weights/bert-base-uncased.npz
```
This will show the following output:
```
MLX BERT:
[[[-0.17057164 0.08602728 -0.12471077 ... -0.09469379 -0.00275938
0.28314582]
[ 0.15222196 -0.48997563 -0.26665813 ... -0.19935863 -0.17162783
-0.51360303]
[ 0.9460105 0.1358298 -0.2945672 ... 0.00868467 -0.90271163
-0.2785422 ]]]
```
They can be compared against the 🤗 implementation with:
```sh
python hf_model.py \
--bert-model bert-base-uncased
```
This will show:
```
HF BERT:
[[[-0.17057131 0.08602707 -0.12471108 ... -0.09469365 -0.00275959
0.28314728]
[ 0.15222463 -0.48997375 -0.26665992 ... -0.19936043 -0.17162988
-0.5136028 ]
[ 0.946011 0.13582966 -0.29456618 ... 0.00868565 -0.90271175
-0.27854213]]]
```
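To use the MLX model as a library rather than through the scripts, here is a minimal sketch that mirrors what `run` in `model.py` does (it assumes the weights were converted to the default path):
```python
import mlx.core as mx
from mlx.utils import tree_map, tree_unflatten
from transformers import AutoTokenizer

from model import Bert, model_configs

# Build the model and load the converted weights.
model = Bert(model_configs["bert-base-uncased"])
weights = tree_unflatten(list(mx.load("weights/bert-base-uncased.npz").items()))
weights = tree_map(lambda p: mx.array(p), weights)
model.update(weights)

# Tokenize with the matching Hugging Face tokenizer and run a forward pass.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
tokens = tokenizer(["An example sentence."], return_tensors="np", padding=True)
tokens = {key: mx.array(v) for key, v in tokens.items()}
last_hidden_state, pooled = model(**tokens)
```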
## To-dos
- [x] Fix position encodings
- [x] Load BERT large and cased variants
- [x] Add example usage

48
bert/convert.py Normal file

@@ -0,0 +1,48 @@
from transformers import BertModel
import argparse
import numpy
def replace_key(key: str) -> str:
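    """Map a Hugging Face BERT parameter name onto the attribute path used by the MLX model."""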
key = key.replace(".layer.", ".layers.")
key = key.replace(".self.key.", ".key_proj.")
key = key.replace(".self.query.", ".query_proj.")
key = key.replace(".self.value.", ".value_proj.")
key = key.replace(".attention.output.dense.", ".attention.out_proj.")
key = key.replace(".attention.output.LayerNorm.", ".ln1.")
key = key.replace(".output.LayerNorm.", ".ln2.")
key = key.replace(".intermediate.dense.", ".linear1.")
key = key.replace(".output.dense.", ".linear2.")
key = key.replace(".LayerNorm.", ".norm.")
key = key.replace("pooler.dense.", "pooler.")
return key
def convert(bert_model: str, mlx_model: str) -> None:
model = BertModel.from_pretrained(bert_model)
# save the tensors
tensors = {
replace_key(key): tensor.numpy() for key, tensor in model.state_dict().items()
}
numpy.savez(mlx_model, **tensors)
# save the tokenizer
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert BERT weights to MLX.")
parser.add_argument(
"--bert-model",
type=str,
default="bert-base-uncased",
help="The huggingface name of the BERT model to save.",
)
parser.add_argument(
"--mlx-model",
type=str,
default="weights/bert-base-uncased.npz",
help="The output path for the MLX BERT weights.",
)
args = parser.parse_args()
convert(args.bert_model, args.mlx_model)

36
bert/hf_model.py Normal file

@@ -0,0 +1,36 @@
from transformers import AutoModel, AutoTokenizer
import argparse
def run(bert_model: str):
batch = [
"This is an example of BERT working on MLX.",
"A second string",
"This is another string.",
]
tokenizer = AutoTokenizer.from_pretrained(bert_model)
torch_model = AutoModel.from_pretrained(bert_model)
torch_tokens = tokenizer(batch, return_tensors="pt", padding=True)
torch_forward = torch_model(**torch_tokens)
torch_output = torch_forward.last_hidden_state.detach().numpy()
torch_pooled = torch_forward.pooler_output.detach().numpy()
print("\n HF BERT:")
print(torch_output)
print("\n\n HF Pooled:")
print(torch_pooled[0, :20])
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--bert-model",
type=str,
default="bert-base-uncased",
help="The huggingface name of the BERT model to save.",
)
args = parser.parse_args()
run(args.bert_model)

246
bert/model.py Normal file

@@ -0,0 +1,246 @@
from typing import Optional
from dataclasses import dataclass
from mlx.utils import tree_unflatten, tree_map
from mlx.nn.layers.base import Module
from mlx.nn.layers.linear import Linear
from mlx.nn.layers.normalization import LayerNorm
from transformers import AutoTokenizer
import mlx.core as mx
import mlx.nn as nn
import argparse
import numpy
import math
@dataclass
class ModelArgs:
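    # Note: `intermediate_size` is used throughout this file as the transformer hidden size
    # (768 for base, 1024 for large); the feed-forward width is derived from it (4x) in
    # TransformerEncoderLayer.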
intermediate_size: int = 768
num_attention_heads: int = 12
num_hidden_layers: int = 12
vocab_size: int = 30522
attention_probs_dropout_prob: float = 0.1
hidden_dropout_prob: float = 0.1
layer_norm_eps: float = 1e-12
max_position_embeddings: int = 512
model_configs = {
"bert-base-uncased": ModelArgs(),
"bert-base-cased": ModelArgs(),
"bert-large-uncased": ModelArgs(
intermediate_size=1024, num_attention_heads=16, num_hidden_layers=24
),
"bert-large-cased": ModelArgs(
intermediate_size=1024, num_attention_heads=16, num_hidden_layers=24
),
}
class MultiHeadAttention(Module):
"""
Minor update to the MultiHeadAttention module to ensure that the
projections use bias.
"""
def __init__(
self,
dims: int,
num_heads: int,
query_input_dims: Optional[int] = None,
key_input_dims: Optional[int] = None,
value_input_dims: Optional[int] = None,
value_dims: Optional[int] = None,
value_output_dims: Optional[int] = None,
):
super().__init__()
if (dims % num_heads) != 0:
raise ValueError(
f"The input feature dimensions should be divisible by the number of heads ({dims} % {num_heads}) != 0"
)
query_input_dims = query_input_dims or dims
key_input_dims = key_input_dims or dims
value_input_dims = value_input_dims or key_input_dims
value_dims = value_dims or dims
value_output_dims = value_output_dims or dims
self.num_heads = num_heads
self.query_proj = Linear(query_input_dims, dims, True)
self.key_proj = Linear(key_input_dims, dims, True)
self.value_proj = Linear(value_input_dims, value_dims, True)
self.out_proj = Linear(value_dims, value_output_dims, True)
def __call__(self, queries, keys, values, mask=None):
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
num_heads = self.num_heads
B, L, D = queries.shape
_, S, _ = keys.shape
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
# Dimensions are [batch x num heads x sequence x hidden dim]
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ keys
if mask is not None:
            mask = self.convert_mask_to_additive_causal_mask(mask)
mask = mx.expand_dims(mask, (1, 2))
mask = mx.broadcast_to(mask, scores.shape)
scores = scores + mask.astype(scores.dtype)
scores = mx.softmax(scores, axis=-1)
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.out_proj(values_hat)
    def convert_mask_to_additive_causal_mask(
self, mask: mx.array, dtype: mx.Dtype = mx.float32
) -> mx.array:
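        """Convert a 0/1 padding mask into an additive mask: 0 where attended, -1e9 at padded positions."""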
mask = mask == 0
mask = mask.astype(dtype) * -1e9
return mask
class TransformerEncoderLayer(Module):
"""
A transformer encoder layer with (the original BERT) post-normalization.
"""
def __init__(
self,
dims: int,
num_heads: int,
mlp_dims: Optional[int] = None,
layer_norm_eps: float = 1e-12,
):
super().__init__()
mlp_dims = mlp_dims or dims * 4
self.attention = MultiHeadAttention(dims, num_heads)
self.ln1 = LayerNorm(dims, eps=layer_norm_eps)
self.ln2 = LayerNorm(dims, eps=layer_norm_eps)
self.linear1 = Linear(dims, mlp_dims)
self.linear2 = Linear(mlp_dims, dims)
self.gelu = nn.GELU()
def __call__(self, x, mask):
attention_out = self.attention(x, x, x, mask)
add_and_norm = self.ln1(x + attention_out)
ff = self.linear1(add_and_norm)
ff_gelu = self.gelu(ff)
ff_out = self.linear2(ff_gelu)
x = self.ln2(ff_out + add_and_norm)
return x
class TransformerEncoder(Module):
def __init__(
self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
):
super().__init__()
self.layers = [
TransformerEncoderLayer(dims, num_heads, mlp_dims)
for i in range(num_layers)
]
def __call__(self, x, mask):
for l in self.layers:
x = l(x, mask)
return x
class BertEmbeddings(nn.Module):
def __init__(self, config: ModelArgs):
self.word_embeddings = nn.Embedding(config.vocab_size, config.intermediate_size)
self.token_type_embeddings = nn.Embedding(2, config.intermediate_size)
self.position_embeddings = nn.Embedding(
config.max_position_embeddings, config.intermediate_size
)
self.norm = nn.LayerNorm(config.intermediate_size, eps=config.layer_norm_eps)
def __call__(self, input_ids: mx.array, token_type_ids: mx.array) -> mx.array:
words = self.word_embeddings(input_ids)
position = self.position_embeddings(
mx.broadcast_to(mx.arange(input_ids.shape[1]), input_ids.shape)
)
token_types = self.token_type_embeddings(token_type_ids)
embeddings = position + words + token_types
return self.norm(embeddings)
class Bert(nn.Module):
def __init__(self, config: ModelArgs):
self.embeddings = BertEmbeddings(config)
self.encoder = TransformerEncoder(
num_layers=config.num_hidden_layers,
dims=config.intermediate_size,
num_heads=config.num_attention_heads,
)
        self.pooler = nn.Linear(config.intermediate_size, config.intermediate_size)
def __call__(
self,
input_ids: mx.array,
token_type_ids: mx.array,
attention_mask: mx.array | None = None,
) -> tuple[mx.array, mx.array]:
x = self.embeddings(input_ids, token_type_ids)
y = self.encoder(x, attention_mask)
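        # The pooled output is a tanh-activated projection of the first ([CLS]) token's hidden state.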
return y, mx.tanh(self.pooler(y[:, 0]))
def run(bert_model: str, mlx_model: str):
batch = [
"This is an example of BERT working on MLX.",
"A second string",
"This is another string.",
]
model = Bert(model_configs[bert_model])
weights = mx.load(mlx_model)
weights = tree_unflatten(list(weights.items()))
weights = tree_map(lambda p: mx.array(p), weights)
model.update(weights)
tokenizer = AutoTokenizer.from_pretrained(bert_model)
tokens = tokenizer(batch, return_tensors="np", padding=True)
tokens = {key: mx.array(v) for key, v in tokens.items()}
mlx_output, mlx_pooled = model(**tokens)
mlx_output = numpy.array(mlx_output)
mlx_pooled = numpy.array(mlx_pooled)
print("MLX BERT:")
print(mlx_output)
print("\n\nMLX Pooled:")
print(mlx_pooled[0, :20])
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert BERT weights to MLX.")
parser.add_argument(
"--bert-model",
type=str,
default="bert-base-uncased",
help="The huggingface name of the BERT model to save.",
)
parser.add_argument(
"--mlx-model",
type=str,
default="weights/bert-base-uncased.npz",
help="The output path for the MLX BERT weights.",
)
args = parser.parse_args()
run(args.bert_model, args.mlx_model)

1
bert/weights/.gitignore vendored Normal file

@@ -0,0 +1 @@
*.npz