Merge pull request #43 from jbarrow/main

BERT implementation
This commit is contained in:
Awni Hannun 2023-12-09 09:03:49 -08:00 committed by GitHub
commit 46c6bbe0a1
6 changed files with 413 additions and 0 deletions

78
bert/README.md Normal file

@@ -0,0 +1,78 @@
# BERT
An implementation of BERT [(Devlin et al., 2019)](https://aclanthology.org/N19-1423/) in MLX.
## Downloading and Converting Weights
The `convert.py` script relies on `transformers` to download the weights and exports them as a single `.npz` file.
```sh
python convert.py \
    --bert-model bert-base-uncased \
    --mlx-model weights/bert-base-uncased.npz
```
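As a quick sanity check on the export, you could inspect the archive (a sketch; the key names are produced by the `replace_key` mapping in `convert.py`):
```python
import mlx.core as mx

weights = mx.load("weights/bert-base-uncased.npz")
print(len(weights))         # number of parameter arrays
print(sorted(weights)[:3])  # a few of the converted key names
```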
## Usage
To use the `Bert` model in your own code, you can load it with:
```python
import mlx.core as mx

from model import Bert, load_model

model, tokenizer = load_model(
    "bert-base-uncased",
    "weights/bert-base-uncased.npz")

batch = ["This is an example of BERT working on MLX."]
tokens = tokenizer(batch, return_tensors="np", padding=True)
tokens = {key: mx.array(v) for key, v in tokens.items()}
output, pooled = model(**tokens)
```
The `output` contains a `Batch x Tokens x Dims` tensor, representing a vector for every input token.
If you want to train anything at the **token level**, you'll want to use this.
The `pooled` contains a `Batch x Dims` tensor, which is the pooled representation for each input.
If you want to train a **classification** model, you'll want to use this (see the sketch below).
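For instance, a minimal classification head over the pooled output could look like the following. This is only a sketch: `BertClassifier` and `NUM_CLASSES` are illustrative assumptions, not part of this repo.
```python
import mlx.core as mx
import mlx.nn as nn

NUM_CLASSES = 2  # assumption: a binary classification task


class BertClassifier(nn.Module):
    """Hypothetical head: logits from the pooled [CLS] representation."""

    def __init__(self, dims: int = 768, num_classes: int = NUM_CLASSES):
        super().__init__()
        self.classifier = nn.Linear(dims, num_classes)

    def __call__(self, pooled: mx.array) -> mx.array:
        return self.classifier(pooled)


head = BertClassifier()
logits = head(pooled)  # `pooled` from the example above: Batch x Dims
```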
## Comparison with 🤗 `transformers` Implementation
To run the model and have it perform inference on a batch of examples:
```sh
python model.py \
--bert-model bert-base-uncased \
--mlx-model weights/bert-base-uncased.npz
```
This will print the following output:
```
MLX BERT:
[[[-0.52508914 -0.1993871 -0.28210318 ... -0.61125606 0.19114694
0.8227601 ]
[-0.8783862 -0.37107834 -0.52238125 ... -0.5067165 1.0847603
0.31066895]
[-0.70010054 -0.5424497 -0.26593682 ... -0.2688697 0.38338926
0.6557663 ]
...
```
They can be compared against the 🤗 implementation with:
```sh
python hf_model.py \
--bert-model bert-base-uncased
```
This will print:
```
HF BERT:
[[[-0.52508944 -0.1993877 -0.28210333 ... -0.6112575 0.19114678
0.8227603 ]
[-0.878387 -0.371079 -0.522381 ... -0.50671494 1.0847601
0.31066933]
[-0.7001008 -0.5424504 -0.26593733 ... -0.26887015 0.38339025
0.65576553]
...
```
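Rather than eyeballing the printouts, the two outputs can also be compared programmatically. A sketch, assuming `model.py` is importable, converted weights on disk, and a float32-friendly tolerance of `1e-4`:
```python
import numpy as np
import mlx.core as mx
from transformers import AutoModel

from model import load_model

batch = ["This is an example of BERT working on MLX."]

mlx_model, tokenizer = load_model(
    "bert-base-uncased", "weights/bert-base-uncased.npz")
tokens = tokenizer(batch, return_tensors="np", padding=True)
mlx_output, _ = mlx_model(**{k: mx.array(v) for k, v in tokens.items()})

hf_model = AutoModel.from_pretrained("bert-base-uncased")
hf_tokens = tokenizer(batch, return_tensors="pt", padding=True)
hf_output = hf_model(**hf_tokens).last_hidden_state.detach().numpy()

print(np.allclose(np.array(mlx_output), hf_output, atol=1e-4))  # expect True
```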

47
bert/convert.py Normal file

@@ -0,0 +1,47 @@
from transformers import BertModel
import argparse
import numpy


def replace_key(key: str) -> str:
    """Map Hugging Face parameter names to the names used by the MLX model."""
    key = key.replace(".layer.", ".layers.")
    key = key.replace(".self.key.", ".key_proj.")
    key = key.replace(".self.query.", ".query_proj.")
    key = key.replace(".self.value.", ".value_proj.")
    key = key.replace(".attention.output.dense.", ".attention.out_proj.")
    key = key.replace(".attention.output.LayerNorm.", ".ln1.")
    key = key.replace(".output.LayerNorm.", ".ln2.")
    key = key.replace(".intermediate.dense.", ".linear1.")
    key = key.replace(".output.dense.", ".linear2.")
    key = key.replace(".LayerNorm.", ".norm.")
    key = key.replace("pooler.dense.", "pooler.")
    return key


def convert(bert_model: str, mlx_model: str) -> None:
    model = BertModel.from_pretrained(bert_model)

    # Rename every tensor and save them all as a single .npz archive.
    tensors = {
        replace_key(key): tensor.numpy() for key, tensor in model.state_dict().items()
    }
    numpy.savez(mlx_model, **tensors)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert BERT weights to MLX.")
    parser.add_argument(
        "--bert-model",
        choices=[
            "bert-base-uncased",
            "bert-base-cased",
            "bert-large-uncased",
            "bert-large-cased",
        ],
        default="bert-base-uncased",
        help="The Hugging Face name of the BERT model to convert.",
    )
    parser.add_argument(
        "--mlx-model",
        type=str,
        default="weights/bert-base-uncased.npz",
        help="The output path for the MLX BERT weights.",
    )
    args = parser.parse_args()
    convert(args.bert_model, args.mlx_model)

36
bert/hf_model.py Normal file

@@ -0,0 +1,36 @@
from transformers import AutoModel, AutoTokenizer
import argparse


def run(bert_model: str):
    batch = [
        "This is an example of BERT working on MLX.",
        "A second string",
        "This is another string.",
    ]

    tokenizer = AutoTokenizer.from_pretrained(bert_model)
    torch_model = AutoModel.from_pretrained(bert_model)

    torch_tokens = tokenizer(batch, return_tensors="pt", padding=True)
    torch_forward = torch_model(**torch_tokens)
    torch_output = torch_forward.last_hidden_state.detach().numpy()
    torch_pooled = torch_forward.pooler_output.detach().numpy()

    print("\n HF BERT:")
    print(torch_output)
    print("\n\n HF Pooled:")
    print(torch_pooled[0, :20])


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run the BERT model using Hugging Face Transformers."
    )
    parser.add_argument(
        "--bert-model",
        choices=[
            "bert-base-uncased",
            "bert-base-cased",
            "bert-large-uncased",
            "bert-large-cased",
        ],
        default="bert-base-uncased",
        help="The Hugging Face name of the BERT model to run.",
    )
    args = parser.parse_args()
    run(args.bert_model)

248
bert/model.py Normal file

@@ -0,0 +1,248 @@
from typing import Optional
from dataclasses import dataclass
from transformers import BertTokenizer
from mlx.utils import tree_unflatten
import mlx.core as mx
import mlx.nn as nn
import argparse
import numpy
import math


@dataclass
class ModelArgs:
    hidden_size: int = 768  # model width; the FFN intermediate size is 4 * hidden_size
    num_attention_heads: int = 12
    num_hidden_layers: int = 12
    vocab_size: int = 30522
    attention_probs_dropout_prob: float = 0.1  # kept for config parity; dropout
    hidden_dropout_prob: float = 0.1  # is not applied in this implementation
    layer_norm_eps: float = 1e-12
    max_position_embeddings: int = 512


model_configs = {
    "bert-base-uncased": ModelArgs(),
    "bert-base-cased": ModelArgs(),
    "bert-large-uncased": ModelArgs(
        hidden_size=1024, num_attention_heads=16, num_hidden_layers=24
    ),
    "bert-large-cased": ModelArgs(
        hidden_size=1024, num_attention_heads=16, num_hidden_layers=24
    ),
}
class MultiHeadAttention(nn.Module):
    """
    Minor update to the MLX MultiHeadAttention module to ensure that the
    projections use bias.
    """

    def __init__(
        self,
        dims: int,
        num_heads: int,
        query_input_dims: Optional[int] = None,
        key_input_dims: Optional[int] = None,
        value_input_dims: Optional[int] = None,
        value_dims: Optional[int] = None,
        value_output_dims: Optional[int] = None,
    ):
        super().__init__()

        if (dims % num_heads) != 0:
            raise ValueError(
                "The input feature dimensions should be divisible by the "
                f"number of heads ({dims} % {num_heads}) != 0"
            )

        query_input_dims = query_input_dims or dims
        key_input_dims = key_input_dims or dims
        value_input_dims = value_input_dims or key_input_dims
        value_dims = value_dims or dims
        value_output_dims = value_output_dims or dims

        self.num_heads = num_heads
        self.query_proj = nn.Linear(query_input_dims, dims, True)
        self.key_proj = nn.Linear(key_input_dims, dims, True)
        self.value_proj = nn.Linear(value_input_dims, value_dims, True)
        self.out_proj = nn.Linear(value_dims, value_output_dims, True)

    def __call__(self, queries, keys, values, mask=None):
        queries = self.query_proj(queries)
        keys = self.key_proj(keys)
        values = self.value_proj(values)

        num_heads = self.num_heads
        B, L, D = queries.shape
        _, S, _ = keys.shape
        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
        values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)

        # Dimensions are [batch x num heads x sequence x hidden dim].
        scale = math.sqrt(1 / queries.shape[-1])
        scores = (queries * scale) @ keys
        if mask is not None:
            mask = self.convert_mask_to_additive_mask(mask)
            mask = mx.expand_dims(mask, (1, 2))
            mask = mx.broadcast_to(mask, scores.shape)
            scores = scores + mask.astype(scores.dtype)
        scores = mx.softmax(scores, axis=-1)
        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)

        return self.out_proj(values_hat)

    def convert_mask_to_additive_mask(
        self, mask: mx.array, dtype: mx.Dtype = mx.float32
    ) -> mx.array:
        # Turn a 0/1 padding mask into additive form: 0 where attended,
        # -1e9 (effectively -inf) where masked. Note this is a padding
        # mask, not a causal one.
        mask = mask == 0
        mask = mask.astype(dtype) * -1e9
        return mask
class TransformerEncoderLayer(nn.Module):
    """
    A transformer encoder layer with (the original BERT) post-normalization.
    """

    def __init__(
        self,
        dims: int,
        num_heads: int,
        mlp_dims: Optional[int] = None,
        layer_norm_eps: float = 1e-12,
    ):
        super().__init__()
        mlp_dims = mlp_dims or dims * 4
        self.attention = MultiHeadAttention(dims, num_heads)
        self.ln1 = nn.LayerNorm(dims, eps=layer_norm_eps)
        self.ln2 = nn.LayerNorm(dims, eps=layer_norm_eps)
        self.linear1 = nn.Linear(dims, mlp_dims)
        self.linear2 = nn.Linear(mlp_dims, dims)
        self.gelu = nn.GELU()

    def __call__(self, x, mask):
        attention_out = self.attention(x, x, x, mask)
        add_and_norm = self.ln1(x + attention_out)

        ff = self.linear1(add_and_norm)
        ff_gelu = self.gelu(ff)
        ff_out = self.linear2(ff_gelu)
        x = self.ln2(ff_out + add_and_norm)

        return x


class TransformerEncoder(nn.Module):
    def __init__(
        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
    ):
        super().__init__()
        self.layers = [
            TransformerEncoderLayer(dims, num_heads, mlp_dims)
            for i in range(num_layers)
        ]

    def __call__(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)

        return x
class BertEmbeddings(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(2, config.hidden_size)
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size
        )
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def __call__(self, input_ids: mx.array, token_type_ids: mx.array) -> mx.array:
        words = self.word_embeddings(input_ids)
        position = self.position_embeddings(
            mx.broadcast_to(mx.arange(input_ids.shape[1]), input_ids.shape)
        )
        token_types = self.token_type_embeddings(token_type_ids)

        # BERT sums the three embeddings, then applies LayerNorm.
        embeddings = position + words + token_types
        return self.norm(embeddings)
class Bert(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.embeddings = BertEmbeddings(config)
        self.encoder = TransformerEncoder(
            num_layers=config.num_hidden_layers,
            dims=config.hidden_size,
            num_heads=config.num_attention_heads,
        )
        # The pooler maps the [CLS] hidden state back to hidden_size,
        # matching the shape of the Hugging Face pooler.dense weights.
        self.pooler = nn.Linear(config.hidden_size, config.hidden_size)

    def __call__(
        self,
        input_ids: mx.array,
        token_type_ids: mx.array,
        attention_mask: Optional[mx.array] = None,
    ) -> tuple[mx.array, mx.array]:
        x = self.embeddings(input_ids, token_type_ids)
        y = self.encoder(x, attention_mask)
        return y, mx.tanh(self.pooler(y[:, 0]))


def load_model(bert_model: str, weights_path: str) -> tuple[Bert, BertTokenizer]:
    # Load the converted .npz weights and unflatten the dotted key names.
    weights = mx.load(weights_path)
    weights = tree_unflatten(list(weights.items()))

    # Create the model and replace its parameters with the loaded weights.
    model = Bert(model_configs[bert_model])
    model.update(weights)

    tokenizer = BertTokenizer.from_pretrained(bert_model)

    return model, tokenizer
def run(bert_model: str, mlx_model: str):
    model, tokenizer = load_model(bert_model, mlx_model)

    batch = [
        "This is an example of BERT working on MLX.",
        "A second string",
        "This is another string.",
    ]
    tokens = tokenizer(batch, return_tensors="np", padding=True)
    tokens = {key: mx.array(v) for key, v in tokens.items()}

    mlx_output, mlx_pooled = model(**tokens)
    mlx_output = numpy.array(mlx_output)
    mlx_pooled = numpy.array(mlx_pooled)

    print("MLX BERT:")
    print(mlx_output)
    print("\n\nMLX Pooled:")
    print(mlx_pooled[0, :20])


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the BERT model using MLX.")
    parser.add_argument(
        "--bert-model",
        type=str,
        default="bert-base-uncased",
        help="The Hugging Face name of the BERT model to run.",
    )
    parser.add_argument(
        "--mlx-model",
        type=str,
        default="weights/bert-base-uncased.npz",
        help="The path of the stored MLX BERT weights (npz file).",
    )
    args = parser.parse_args()
    run(args.bert_model, args.mlx_model)

3
bert/requirements.txt Normal file

@@ -0,0 +1,3 @@
mlx
transformers
numpy

1
bert/weights/.gitignore vendored Normal file

@@ -0,0 +1 @@
*.npz