Cleaning implementation for merge

This commit is contained in:
Joe Barrow
2023-12-09 10:41:15 -05:00
parent e05ee57bab
commit 7320456226
4 changed files with 39 additions and 38 deletions

View File

@@ -1,10 +1,7 @@
from typing import Optional
from dataclasses import dataclass
from mlx.utils import tree_unflatten, tree_map
from mlx.nn.layers.base import Module
from mlx.nn.layers.linear import Linear
from mlx.nn.layers.normalization import LayerNorm
from transformers import AutoTokenizer
from transformers import BertTokenizer
from mlx.utils import tree_unflatten
import mlx.core as mx
import mlx.nn as nn
@@ -37,7 +34,7 @@ model_configs = {
}
class MultiHeadAttention(Module):
class MultiHeadAttention(nn.Module):
"""
Minor update to the MultiHeadAttention module to ensure that the
projections use bias.
@@ -67,10 +64,10 @@ class MultiHeadAttention(Module):
value_output_dims = value_output_dims or dims
self.num_heads = num_heads
self.query_proj = Linear(query_input_dims, dims, True)
self.key_proj = Linear(key_input_dims, dims, True)
self.value_proj = Linear(value_input_dims, value_dims, True)
self.out_proj = Linear(value_dims, value_output_dims, True)
self.query_proj = nn.Linear(query_input_dims, dims, True)
self.key_proj = nn.Linear(key_input_dims, dims, True)
self.value_proj = nn.Linear(value_input_dims, value_dims, True)
self.out_proj = nn.Linear(value_dims, value_output_dims, True)
def __call__(self, queries, keys, values, mask=None):
queries = self.query_proj(queries)
@@ -105,7 +102,7 @@ class MultiHeadAttention(Module):
return mask
class TransformerEncoderLayer(Module):
class TransformerEncoderLayer(nn.Module):
"""
A transformer encoder layer with (the original BERT) post-normalization.
"""
@@ -120,10 +117,10 @@ class TransformerEncoderLayer(Module):
super().__init__()
mlp_dims = mlp_dims or dims * 4
self.attention = MultiHeadAttention(dims, num_heads)
self.ln1 = LayerNorm(dims, eps=layer_norm_eps)
self.ln2 = LayerNorm(dims, eps=layer_norm_eps)
self.linear1 = Linear(dims, mlp_dims)
self.linear2 = Linear(mlp_dims, dims)
self.ln1 = nn.LayerNorm(dims, eps=layer_norm_eps)
self.ln2 = nn.LayerNorm(dims, eps=layer_norm_eps)
self.linear1 = nn.Linear(dims, mlp_dims)
self.linear2 = nn.Linear(mlp_dims, dims)
self.gelu = nn.GELU()
def __call__(self, x, mask):
@@ -138,7 +135,7 @@ class TransformerEncoderLayer(Module):
return x
class TransformerEncoder(Module):
class TransformerEncoder(nn.Module):
def __init__(
self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
):
@@ -149,8 +146,8 @@ class TransformerEncoder(Module):
]
def __call__(self, x, mask):
for l in self.layers:
x = l(x, mask)
for layer in self.layers:
x = layer(x, mask)
return x
@@ -196,23 +193,28 @@ class Bert(nn.Module):
return y, mx.tanh(self.pooler(y[:, 0]))
def load_model(bert_model: str, weights_path: str) -> tuple[Bert, BertTokenizer]:
    """Build a Bert model from stored MLX weights plus its HF tokenizer.

    Args:
        bert_model: key into ``model_configs`` and the HF hub name used
            to fetch the matching ``BertTokenizer``.
        weights_path: path to the ``.npz`` file of flattened MLX weights.

    Returns:
        The weight-loaded ``Bert`` model and its ``BertTokenizer``.
    """
    # Rebuild the nested parameter tree from the flat npz key/value pairs,
    # then push it into a freshly constructed model.
    params = tree_unflatten(list(mx.load(weights_path).items()))
    model = Bert(model_configs[bert_model])
    model.update(params)
    return model, BertTokenizer.from_pretrained(bert_model)
def run(bert_model: str, mlx_model: str):
model, tokenizer = load_model(bert_model, mlx_model)
batch = [
"This is an example of BERT working on MLX.",
"A second string",
"This is another string.",
]
model = Bert(model_configs[bert_model])
weights = mx.load(mlx_model)
weights = tree_unflatten(list(weights.items()))
weights = tree_map(lambda p: mx.array(p), weights)
model.update(weights)
tokenizer = AutoTokenizer.from_pretrained(bert_model)
tokens = tokenizer(batch, return_tensors="np", padding=True)
tokens = {key: mx.array(v) for key, v in tokens.items()}
@@ -228,7 +230,7 @@ def run(bert_model: str, mlx_model: str):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert BERT weights to MLX.")
parser = argparse.ArgumentParser(description="Run the BERT model using MLX.")
parser.add_argument(
"--bert-model",
type=str,
@@ -239,8 +241,8 @@ if __name__ == "__main__":
"--mlx-model",
type=str,
default="weights/bert-base-uncased.npz",
help="The output path for the MLX BERT weights.",
help="The path of the stored MLX BERT weights (npz file).",
)
args = parser.parse_args()
run(args.bert_model, args.mlx_model)
run(args.bert_model, args.mlx_model)