Enable more BERT models (#580)

* Update convert.py

* Update model.py

* Update test.py

* Update model.py

* Update convert.py

* Add files via upload

* Update convert.py

* format

* nit

* nit

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
yzimmermann
2024-03-20 01:21:33 +01:00
committed by GitHub
parent b0bcd86a40
commit 4680ef4413
4 changed files with 76 additions and 68 deletions

View File

@@ -1,7 +1,7 @@
import argparse
import numpy
-from transformers import BertModel
+from transformers import AutoModel
def replace_key(key: str) -> str:
@@ -20,7 +20,7 @@ def replace_key(key: str) -> str:
def convert(bert_model: str, mlx_model: str) -> None:
-model = BertModel.from_pretrained(bert_model)
+model = AutoModel.from_pretrained(bert_model)
# save the tensors
tensors = {
replace_key(key): tensor.numpy() for key, tensor in model.state_dict().items()
@@ -32,14 +32,9 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert BERT weights to MLX.")
parser.add_argument(
"--bert-model",
-choices=[
-"bert-base-uncased",
-"bert-base-cased",
-"bert-large-uncased",
-"bert-large-cased",
-],
type=str,
default="bert-base-uncased",
-help="The huggingface name of the BERT model to save.",
+help="The huggingface name of the BERT model to save. Any BERT-like model can be specified.",
)
parser.add_argument(
"--mlx-model",