From bbd7172eefefc0be9e56a910a7601709fb16b5bd Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Tue, 9 Jan 2024 08:44:51 -0800
Subject: [PATCH] Some fixes / cleanup for BERT example (#269)

* some fixes/cleaning for bert + test

* nit
---
 bert/README.md   | 61 +++++++++++++++---------------------------------
 bert/hf_model.py | 43 ----------------------------------
 bert/model.py    | 56 +++++++++++++++++++-------------------------
 bert/test.py     | 34 +++++++++++++++++++++++++++
 4 files changed, 77 insertions(+), 117 deletions(-)
 delete mode 100644 bert/hf_model.py
 create mode 100644 bert/test.py

diff --git a/bert/README.md b/bert/README.md
index 70bc39a9..42e5e957 100644
--- a/bert/README.md
+++ b/bert/README.md
@@ -2,9 +2,15 @@
 
 An implementation of BERT [(Devlin, et al., 2019)](https://aclanthology.org/N19-1423/) within MLX.
 
-## Downloading and Converting Weights
+## Setup
 
-The `convert.py` script relies on `transformers` to download the weights, and exports them as a single `.npz` file.
+Install the requirements:
+
+```
+pip install -r requirements.txt
+```
+
+Then convert the weights with:
 
 ```
 python convert.py \
@@ -30,49 +36,20 @@ tokens = {key: mx.array(v) for key, v in tokens.items()}
 output, pooled = model(**tokens)
 ```
 
-The `output` contains a `Batch x Tokens x Dims` tensor, representing a vector for every input token.
-If you want to train anything at a **token-level**, you'll want to use this.
+The `output` contains a `Batch x Tokens x Dims` tensor, representing a vector
+for every input token. If you want to train anything at a **token-level**,
+you'll want to use this.
 
-The `pooled` contains a `Batch x Dims` tensor, which is the pooled representation for each input.
-If you want to train a **classification** model, you'll want to use this.
+The `pooled` contains a `Batch x Dims` tensor, which is the pooled
+representation for each input. If you want to train a **classification**
+model, you'll want to use this.
 
-## Comparison with 🤗 `transformers` Implementation
-In order to run the model, and have it forward inference on a batch of examples:
+## Test
+
+You can check that the output for the default model (`bert-base-uncased`)
+matches the Hugging Face version with:
 
-```sh
-python model.py \
-    --bert-model bert-base-uncased \
-    --mlx-model weights/bert-base-uncased.npz
 ```
-
-Which will show the following outputs:
-```
-MLX BERT:
-[[[-0.52508914 -0.1993871  -0.28210318 ... -0.61125606  0.19114694
-    0.8227601 ]
-  [-0.8783862  -0.37107834 -0.52238125 ... -0.5067165   1.0847603
-    0.31066895]
-  [-0.70010054 -0.5424497  -0.26593682 ... -0.2688697   0.38338926
-    0.6557663 ]
-  ...
-```
-
-They can be compared against the 🤗 implementation with:
-
-```sh
-python hf_model.py \
-    --bert-model bert-base-uncased
-```
-
-Which will show:
-```
- HF BERT:
-[[[-0.52508944 -0.1993877  -0.28210333 ... -0.6112575   0.19114678
-    0.8227603 ]
-  [-0.878387   -0.371079   -0.522381   ... -0.50671494  1.0847601
-    0.31066933]
-  [-0.7001008  -0.5424504  -0.26593733 ... -0.26887015  0.38339025
-    0.65576553]
-  ...
+python test.py
 ```
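The README hunk above distinguishes the per-token `output` (for token-level tasks) from `pooled` (for classification). As a minimal sketch of that distinction, not part of the patch itself, the snippet below runs the converted model and attaches an untrained two-label `nn.Linear` head to the pooled vector; the head, the example sentences, and the weights path (which assumes the Setup section's conversion step has been run) are illustrative only.

```python
import mlx.core as mx
import mlx.nn as nn

from model import load_model  # bert/model.py as modified by this patch

# Assumes convert.py has already produced the npz file described in Setup.
model, tokenizer = load_model("bert-base-uncased", "weights/bert-base-uncased.npz")

batch = ["This is an example of BERT working in MLX.", "A second string"]
tokens = tokenizer(batch, return_tensors="np", padding=True)
tokens = {key: mx.array(v) for key, v in tokens.items()}

# output: (batch, tokens, dims) for token-level tasks; pooled: (batch, dims) for classification.
output, pooled = model(**tokens)

# Hypothetical, untrained two-label classification head on the pooled representation.
classifier = nn.Linear(pooled.shape[-1], 2)
logits = classifier(pooled)
print(logits.shape)  # two logits per input sentence
```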
diff --git a/bert/hf_model.py b/bert/hf_model.py
deleted file mode 100644
index 7b8ad722..00000000
--- a/bert/hf_model.py
+++ /dev/null
@@ -1,43 +0,0 @@
-import argparse
-
-from transformers import AutoModel, AutoTokenizer
-
-
-def run(bert_model: str):
-    batch = [
-        "This is an example of BERT working on MLX.",
-        "A second string",
-        "This is another string.",
-    ]
-
-    tokenizer = AutoTokenizer.from_pretrained(bert_model)
-    torch_model = AutoModel.from_pretrained(bert_model)
-    torch_tokens = tokenizer(batch, return_tensors="pt", padding=True)
-    torch_forward = torch_model(**torch_tokens)
-    torch_output = torch_forward.last_hidden_state.detach().numpy()
-    torch_pooled = torch_forward.pooler_output.detach().numpy()
-
-    print("\n HF BERT:")
-    print(torch_output)
-    print("\n\n HF Pooled:")
-    print(torch_pooled[0, :20])
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        description="Run the BERT model using Hugging Face Transformers."
-    )
-    parser.add_argument(
-        "--bert-model",
-        choices=[
-            "bert-base-uncased",
-            "bert-base-cased",
-            "bert-large-uncased",
-            "bert-large-cased",
-        ],
-        default="bert-base-uncased",
-        help="The huggingface name of the BERT model to save.",
-    )
-    args = parser.parse_args()
-
-    run(args.bert_model)
diff --git a/bert/model.py b/bert/model.py
index f34b3272..11a24659 100644
--- a/bert/model.py
+++ b/bert/model.py
@@ -1,6 +1,7 @@
 import argparse
 from dataclasses import dataclass
-from typing import Optional
+from pathlib import Path
+from typing import List, Optional
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -12,7 +13,7 @@ from transformers import BertTokenizer
 
 @dataclass
 class ModelArgs:
-    intermediate_size: int = 768
+    dim: int = 768
     num_attention_heads: int = 12
     num_hidden_layers: int = 12
     vocab_size: int = 30522
@@ -26,10 +27,10 @@ model_configs = {
     "bert-base-uncased": ModelArgs(),
     "bert-base-cased": ModelArgs(),
     "bert-large-uncased": ModelArgs(
-        intermediate_size=1024, num_attention_heads=16, num_hidden_layers=24
+        dim=1024, num_attention_heads=16, num_hidden_layers=24
    ),
     "bert-large-cased": ModelArgs(
-        intermediate_size=1024, num_attention_heads=16, num_hidden_layers=24
+        dim=1024, num_attention_heads=16, num_hidden_layers=24
     ),
 }
 
@@ -86,12 +87,12 @@ class TransformerEncoder(nn.Module):
 
 class BertEmbeddings(nn.Module):
     def __init__(self, config: ModelArgs):
-        self.word_embeddings = nn.Embedding(config.vocab_size, config.intermediate_size)
-        self.token_type_embeddings = nn.Embedding(2, config.intermediate_size)
+        self.word_embeddings = nn.Embedding(config.vocab_size, config.dim)
+        self.token_type_embeddings = nn.Embedding(2, config.dim)
         self.position_embeddings = nn.Embedding(
-            config.max_position_embeddings, config.intermediate_size
+            config.max_position_embeddings, config.dim
         )
-        self.norm = nn.LayerNorm(config.intermediate_size, eps=config.layer_norm_eps)
+        self.norm = nn.LayerNorm(config.dim, eps=config.layer_norm_eps)
 
     def __call__(self, input_ids: mx.array, token_type_ids: mx.array) -> mx.array:
         words = self.word_embeddings(input_ids)
@@ -109,10 +110,10 @@ class Bert(nn.Module):
         self.embeddings = BertEmbeddings(config)
         self.encoder = TransformerEncoder(
             num_layers=config.num_hidden_layers,
-            dims=config.intermediate_size,
+            dims=config.dim,
             num_heads=config.num_attention_heads,
         )
-        self.pooler = nn.Linear(config.intermediate_size, config.vocab_size)
+        self.pooler = nn.Linear(config.dim, config.dim)
 
     def __call__(
         self,
@@ -132,39 +133,25 @@ class Bert(nn.Module):
 
 
 def load_model(bert_model: str, weights_path: str) -> tuple[Bert, BertTokenizer]:
-    # load the weights npz
-    weights = mx.load(weights_path)
-    weights = tree_unflatten(list(weights.items()))
+    if not Path(weights_path).exists():
+        raise ValueError(f"No model weights found in {weights_path}")
 
+    # create and update the model
     model = Bert(model_configs[bert_model])
-    model.update(weights)
+    model.load_weights(weights_path)
 
     tokenizer = BertTokenizer.from_pretrained(bert_model)
 
     return model, tokenizer
 
 
-def run(bert_model: str, mlx_model: str):
+def run(bert_model: str, mlx_model: str, batch: List[str]):
     model, tokenizer = load_model(bert_model, mlx_model)
 
-    batch = [
-        "This is an example of BERT working on MLX.",
-        "A second string",
-        "This is another string.",
-    ]
-
     tokens = tokenizer(batch, return_tensors="np", padding=True)
     tokens = {key: mx.array(v) for key, v in tokens.items()}
 
-    mlx_output, mlx_pooled = model(**tokens)
-    mlx_output = numpy.array(mlx_output)
-    mlx_pooled = numpy.array(mlx_pooled)
-
-    print("MLX BERT:")
-    print(mlx_output)
-
-    print("\n\nMLX Pooled:")
-    print(mlx_pooled[0, :20])
+    return model(**tokens)
 
 
 if __name__ == "__main__":
@@ -181,6 +168,11 @@ if __name__ == "__main__":
         default="weights/bert-base-uncased.npz",
         help="The path of the stored MLX BERT weights (npz file).",
     )
+    parser.add_argument(
+        "--text",
+        type=str,
+        default="This is an example of BERT working in MLX",
+        help="The text to generate embeddings for.",
+    )
     args = parser.parse_args()
-
-    run(args.bert_model, args.mlx_model)
+    run(args.bert_model, args.mlx_model, [args.text])
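In `load_model` above, the manual `mx.load` / `tree_unflatten` / `model.update` sequence is replaced by a single `model.load_weights` call. The sketch below is not part of the patch; it simply puts the before and after side by side, and the weights path assumes the conversion step from the README has been run.

```python
import mlx.core as mx
from mlx.utils import tree_unflatten

from model import Bert, model_configs

weights_path = "weights/bert-base-uncased.npz"  # assumes convert.py has been run
model = Bert(model_configs["bert-base-uncased"])

# Before this patch: load the npz and graft the arrays onto the module by hand.
weights = tree_unflatten(list(mx.load(weights_path).items()))
model.update(weights)

# After this patch: the equivalent single call used by load_model.
model.load_weights(weights_path)
```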
diff --git a/bert/test.py b/bert/test.py
new file mode 100644
index 00000000..089fc45f
--- /dev/null
+++ b/bert/test.py
@@ -0,0 +1,34 @@
+from typing import List
+
+import model
+import numpy as np
+from transformers import AutoModel, AutoTokenizer
+
+
+def run_torch(bert_model: str, batch: List[str]):
+    tokenizer = AutoTokenizer.from_pretrained(bert_model)
+    torch_model = AutoModel.from_pretrained(bert_model)
+    torch_tokens = tokenizer(batch, return_tensors="pt", padding=True)
+    torch_forward = torch_model(**torch_tokens)
+    torch_output = torch_forward.last_hidden_state.detach().numpy()
+    torch_pooled = torch_forward.pooler_output.detach().numpy()
+    return torch_output, torch_pooled
+
+
+if __name__ == "__main__":
+    bert_model = "bert-base-uncased"
+    mlx_model = "weights/bert-base-uncased.npz"
+    batch = [
+        "This is an example of BERT working in MLX.",
+        "A second string",
+        "This is another string.",
+    ]
+    torch_output, torch_pooled = run_torch(bert_model, batch)
+    mlx_output, mlx_pooled = model.run(bert_model, mlx_model, batch)
+    assert np.allclose(
+        torch_output, mlx_output, rtol=1e-4, atol=1e-5
+    ), "Model output is different"
+    assert np.allclose(
+        torch_pooled, mlx_pooled, rtol=1e-4, atol=1e-5
+    ), "Model pooled output is different"
+    print("Tests pass :)")
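For reference, the assertions in `test.py` use `np.allclose`, which passes when `|a - b| <= atol + rtol * |b|` holds elementwise. A tiny standalone check, not part of the patch, applies the same tolerances to the first few values printed in the old README (MLX vs. Hugging Face) to show the amount of float32 drift they allow:

```python
import numpy as np

# First three values from the old README's printed outputs (MLX vs. Hugging Face).
mlx_vals = np.array([-0.52508914, -0.1993871, -0.28210318], dtype=np.float32)
hf_vals = np.array([-0.52508944, -0.1993877, -0.28210333], dtype=np.float32)

# Same tolerances as test.py: passes when |a - b| <= atol + rtol * |b| elementwise.
print(np.allclose(mlx_vals, hf_vals, rtol=1e-4, atol=1e-5))  # True
```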