From 73204562263d876a5056e9678ac5cf416dd2d98b Mon Sep 17 00:00:00 2001 From: Joe Barrow Date: Sat, 9 Dec 2023 10:41:15 -0500 Subject: [PATCH] Cleaning implementation for merge --- bert/README.md | 4 +-- bert/convert.py | 3 +-- bert/hf_model.py | 4 +-- bert/model.py | 66 +++++++++++++++++++++++++----------------------- 4 files changed, 39 insertions(+), 38 deletions(-) diff --git a/bert/README.md b/bert/README.md index e1b7a433..bb856ed3 100644 --- a/bert/README.md +++ b/bert/README.md @@ -1,6 +1,6 @@ -# mlxbert +# BERT -An implementation of BERT [(Devlin, et al., 2019)](https://aclanthology.org/N19-1423/) within mlx. +An implementation of BERT [(Devlin, et al., 2019)](https://aclanthology.org/N19-1423/) within MLX. ## Downloading and Converting Weights diff --git a/bert/convert.py b/bert/convert.py index d2b7b624..5a9298d6 100644 --- a/bert/convert.py +++ b/bert/convert.py @@ -26,14 +26,13 @@ def convert(bert_model: str, mlx_model: str) -> None: replace_key(key): tensor.numpy() for key, tensor in model.state_dict().items() } numpy.savez(mlx_model, **tensors) - # save the tokenizer if __name__ == "__main__": parser = argparse.ArgumentParser(description="Convert BERT weights to MLX.") parser.add_argument( "--bert-model", - type=str, + choices=["bert-base-uncased", "bert-base-cased", "bert-large-uncased", "bert-large-cased"], default="bert-base-uncased", help="The huggingface name of the BERT model to save.", ) diff --git a/bert/hf_model.py b/bert/hf_model.py index 13350e4a..9f73028d 100644 --- a/bert/hf_model.py +++ b/bert/hf_model.py @@ -24,10 +24,10 @@ def run(bert_model: str): if __name__ == "__main__": - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser(description="Run the BERT model using HuggingFace Transformers.") parser.add_argument( "--bert-model", - type=str, + choices=["bert-base-uncased", "bert-base-cased", "bert-large-uncased", "bert-large-cased"], default="bert-base-uncased", help="The huggingface name of the BERT model to save.", ) diff --git a/bert/model.py b/bert/model.py index 01ac294b..318f52ce 100644 --- a/bert/model.py +++ b/bert/model.py @@ -1,10 +1,7 @@ from typing import Optional from dataclasses import dataclass -from mlx.utils import tree_unflatten, tree_map -from mlx.nn.layers.base import Module -from mlx.nn.layers.linear import Linear -from mlx.nn.layers.normalization import LayerNorm -from transformers import AutoTokenizer +from transformers import BertTokenizer +from mlx.utils import tree_unflatten import mlx.core as mx import mlx.nn as nn @@ -37,7 +34,7 @@ model_configs = { } -class MultiHeadAttention(Module): +class MultiHeadAttention(nn.Module): """ Minor update to the MultiHeadAttention module to ensure that the projections use bias. @@ -67,10 +64,10 @@ class MultiHeadAttention(Module): value_output_dims = value_output_dims or dims self.num_heads = num_heads - self.query_proj = Linear(query_input_dims, dims, True) - self.key_proj = Linear(key_input_dims, dims, True) - self.value_proj = Linear(value_input_dims, value_dims, True) - self.out_proj = Linear(value_dims, value_output_dims, True) + self.query_proj = nn.Linear(query_input_dims, dims, True) + self.key_proj = nn.Linear(key_input_dims, dims, True) + self.value_proj = nn.Linear(value_input_dims, value_dims, True) + self.out_proj = nn.Linear(value_dims, value_output_dims, True) def __call__(self, queries, keys, values, mask=None): queries = self.query_proj(queries) @@ -105,7 +102,7 @@ class MultiHeadAttention(Module): return mask -class TransformerEncoderLayer(Module): +class TransformerEncoderLayer(nn.Module): """ A transformer encoder layer with (the original BERT) post-normalization. """ @@ -120,10 +117,10 @@ class TransformerEncoderLayer(Module): super().__init__() mlp_dims = mlp_dims or dims * 4 self.attention = MultiHeadAttention(dims, num_heads) - self.ln1 = LayerNorm(dims, eps=layer_norm_eps) - self.ln2 = LayerNorm(dims, eps=layer_norm_eps) - self.linear1 = Linear(dims, mlp_dims) - self.linear2 = Linear(mlp_dims, dims) + self.ln1 = nn.LayerNorm(dims, eps=layer_norm_eps) + self.ln2 = nn.LayerNorm(dims, eps=layer_norm_eps) + self.linear1 = nn.Linear(dims, mlp_dims) + self.linear2 = nn.Linear(mlp_dims, dims) self.gelu = nn.GELU() def __call__(self, x, mask): @@ -138,7 +135,7 @@ class TransformerEncoderLayer(Module): return x -class TransformerEncoder(Module): +class TransformerEncoder(nn.Module): def __init__( self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None ): @@ -149,8 +146,8 @@ class TransformerEncoder(Module): ] def __call__(self, x, mask): - for l in self.layers: - x = l(x, mask) + for layer in self.layers: + x = layer(x, mask) return x @@ -196,23 +193,28 @@ class Bert(nn.Module): return y, mx.tanh(self.pooler(y[:, 0])) +def load_model(bert_model: str, weights_path: str) -> tuple[Bert, BertTokenizer]: + # load the weights npz + weights = mx.load(weights_path) + weights = tree_unflatten(list(weights.items())) + # create and update the model + model = Bert(model_configs[bert_model]) + model.update(weights) + + tokenizer = BertTokenizer.from_pretrained(bert_model) + + return model, tokenizer + + def run(bert_model: str, mlx_model: str): + model, tokenizer = load_model(bert_model, mlx_model) + batch = [ "This is an example of BERT working on MLX.", "A second string", "This is another string.", ] - - model = Bert(model_configs[bert_model]) - - weights = mx.load(mlx_model) - weights = tree_unflatten(list(weights.items())) - weights = tree_map(lambda p: mx.array(p), weights) - - model.update(weights) - - tokenizer = AutoTokenizer.from_pretrained(bert_model) - + tokens = tokenizer(batch, return_tensors="np", padding=True) tokens = {key: mx.array(v) for key, v in tokens.items()} @@ -228,7 +230,7 @@ def run(bert_model: str, mlx_model: str): if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Convert BERT weights to MLX.") + parser = argparse.ArgumentParser(description="Run the BERT model using MLX.") parser.add_argument( "--bert-model", type=str, @@ -239,8 +241,8 @@ if __name__ == "__main__": "--mlx-model", type=str, default="weights/bert-base-uncased.npz", - help="The output path for the MLX BERT weights.", + help="The path of the stored MLX BERT weights (npz file).", ) args = parser.parse_args() - run(args.bert_model, args.mlx_model) \ No newline at end of file + run(args.bert_model, args.mlx_model)