Cleaning implementation for merge

This commit is contained in:
Joe Barrow 2023-12-09 10:41:15 -05:00
parent e05ee57bab
commit 7320456226
4 changed files with 39 additions and 38 deletions

View File

@ -1,6 +1,6 @@
# mlxbert # BERT
An implementation of BERT [(Devlin, et al., 2019)](https://aclanthology.org/N19-1423/) within mlx. An implementation of BERT [(Devlin, et al., 2019)](https://aclanthology.org/N19-1423/) within MLX.
## Downloading and Converting Weights ## Downloading and Converting Weights

View File

@ -26,14 +26,13 @@ def convert(bert_model: str, mlx_model: str) -> None:
replace_key(key): tensor.numpy() for key, tensor in model.state_dict().items() replace_key(key): tensor.numpy() for key, tensor in model.state_dict().items()
} }
numpy.savez(mlx_model, **tensors) numpy.savez(mlx_model, **tensors)
# save the tokenizer
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert BERT weights to MLX.") parser = argparse.ArgumentParser(description="Convert BERT weights to MLX.")
parser.add_argument( parser.add_argument(
"--bert-model", "--bert-model",
type=str, choices=["bert-base-uncased", "bert-base-cased", "bert-large-uncased", "bert-large-cased"],
default="bert-base-uncased", default="bert-base-uncased",
help="The huggingface name of the BERT model to save.", help="The huggingface name of the BERT model to save.",
) )

View File

@ -24,10 +24,10 @@ def run(bert_model: str):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser(description="Run the BERT model using HuggingFace Transformers.")
parser.add_argument( parser.add_argument(
"--bert-model", "--bert-model",
type=str, choices=["bert-base-uncased", "bert-base-cased", "bert-large-uncased", "bert-large-cased"],
default="bert-base-uncased", default="bert-base-uncased",
help="The huggingface name of the BERT model to save.", help="The huggingface name of the BERT model to save.",
) )

View File

@ -1,10 +1,7 @@
from typing import Optional from typing import Optional
from dataclasses import dataclass from dataclasses import dataclass
from mlx.utils import tree_unflatten, tree_map from transformers import BertTokenizer
from mlx.nn.layers.base import Module from mlx.utils import tree_unflatten
from mlx.nn.layers.linear import Linear
from mlx.nn.layers.normalization import LayerNorm
from transformers import AutoTokenizer
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
@ -37,7 +34,7 @@ model_configs = {
} }
class MultiHeadAttention(Module): class MultiHeadAttention(nn.Module):
""" """
Minor update to the MultiHeadAttention module to ensure that the Minor update to the MultiHeadAttention module to ensure that the
projections use bias. projections use bias.
@ -67,10 +64,10 @@ class MultiHeadAttention(Module):
value_output_dims = value_output_dims or dims value_output_dims = value_output_dims or dims
self.num_heads = num_heads self.num_heads = num_heads
self.query_proj = Linear(query_input_dims, dims, True) self.query_proj = nn.Linear(query_input_dims, dims, True)
self.key_proj = Linear(key_input_dims, dims, True) self.key_proj = nn.Linear(key_input_dims, dims, True)
self.value_proj = Linear(value_input_dims, value_dims, True) self.value_proj = nn.Linear(value_input_dims, value_dims, True)
self.out_proj = Linear(value_dims, value_output_dims, True) self.out_proj = nn.Linear(value_dims, value_output_dims, True)
def __call__(self, queries, keys, values, mask=None): def __call__(self, queries, keys, values, mask=None):
queries = self.query_proj(queries) queries = self.query_proj(queries)
@ -105,7 +102,7 @@ class MultiHeadAttention(Module):
return mask return mask
class TransformerEncoderLayer(Module): class TransformerEncoderLayer(nn.Module):
""" """
A transformer encoder layer with (the original BERT) post-normalization. A transformer encoder layer with (the original BERT) post-normalization.
""" """
@ -120,10 +117,10 @@ class TransformerEncoderLayer(Module):
super().__init__() super().__init__()
mlp_dims = mlp_dims or dims * 4 mlp_dims = mlp_dims or dims * 4
self.attention = MultiHeadAttention(dims, num_heads) self.attention = MultiHeadAttention(dims, num_heads)
self.ln1 = LayerNorm(dims, eps=layer_norm_eps) self.ln1 = nn.LayerNorm(dims, eps=layer_norm_eps)
self.ln2 = LayerNorm(dims, eps=layer_norm_eps) self.ln2 = nn.LayerNorm(dims, eps=layer_norm_eps)
self.linear1 = Linear(dims, mlp_dims) self.linear1 = nn.Linear(dims, mlp_dims)
self.linear2 = Linear(mlp_dims, dims) self.linear2 = nn.Linear(mlp_dims, dims)
self.gelu = nn.GELU() self.gelu = nn.GELU()
def __call__(self, x, mask): def __call__(self, x, mask):
@ -138,7 +135,7 @@ class TransformerEncoderLayer(Module):
return x return x
class TransformerEncoder(Module): class TransformerEncoder(nn.Module):
def __init__( def __init__(
self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
): ):
@ -149,8 +146,8 @@ class TransformerEncoder(Module):
] ]
def __call__(self, x, mask): def __call__(self, x, mask):
for l in self.layers: for layer in self.layers:
x = l(x, mask) x = layer(x, mask)
return x return x
@ -196,23 +193,28 @@ class Bert(nn.Module):
return y, mx.tanh(self.pooler(y[:, 0])) return y, mx.tanh(self.pooler(y[:, 0]))
def load_model(bert_model: str, weights_path: str) -> tuple[Bert, BertTokenizer]:
# load the weights npz
weights = mx.load(weights_path)
weights = tree_unflatten(list(weights.items()))
# create and update the model
model = Bert(model_configs[bert_model])
model.update(weights)
tokenizer = BertTokenizer.from_pretrained(bert_model)
return model, tokenizer
def run(bert_model: str, mlx_model: str): def run(bert_model: str, mlx_model: str):
model, tokenizer = load_model(bert_model, mlx_model)
batch = [ batch = [
"This is an example of BERT working on MLX.", "This is an example of BERT working on MLX.",
"A second string", "A second string",
"This is another string.", "This is another string.",
] ]
model = Bert(model_configs[bert_model])
weights = mx.load(mlx_model)
weights = tree_unflatten(list(weights.items()))
weights = tree_map(lambda p: mx.array(p), weights)
model.update(weights)
tokenizer = AutoTokenizer.from_pretrained(bert_model)
tokens = tokenizer(batch, return_tensors="np", padding=True) tokens = tokenizer(batch, return_tensors="np", padding=True)
tokens = {key: mx.array(v) for key, v in tokens.items()} tokens = {key: mx.array(v) for key, v in tokens.items()}
@ -228,7 +230,7 @@ def run(bert_model: str, mlx_model: str):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert BERT weights to MLX.") parser = argparse.ArgumentParser(description="Run the BERT model using MLX.")
parser.add_argument( parser.add_argument(
"--bert-model", "--bert-model",
type=str, type=str,
@ -239,7 +241,7 @@ if __name__ == "__main__":
"--mlx-model", "--mlx-model",
type=str, type=str,
default="weights/bert-base-uncased.npz", default="weights/bert-base-uncased.npz",
help="The output path for the MLX BERT weights.", help="The path of the stored MLX BERT weights (npz file).",
) )
args = parser.parse_args() args = parser.parse_args()