mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-10-24 06:28:07 +08:00
Enable more BERT models (#580)
* Update convert.py * Update model.py * Update test.py * Update model.py * Update convert.py * Add files via upload * Update convert.py * format * nit * nit --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import argparse
|
||||
|
||||
import numpy
|
||||
from transformers import BertModel
|
||||
from transformers import AutoModel
|
||||
|
||||
|
||||
def replace_key(key: str) -> str:
|
||||
@@ -20,7 +20,7 @@ def replace_key(key: str) -> str:
|
||||
|
||||
|
||||
def convert(bert_model: str, mlx_model: str) -> None:
|
||||
model = BertModel.from_pretrained(bert_model)
|
||||
model = AutoModel.from_pretrained(bert_model)
|
||||
# save the tensors
|
||||
tensors = {
|
||||
replace_key(key): tensor.numpy() for key, tensor in model.state_dict().items()
|
||||
@@ -32,14 +32,9 @@ if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Convert BERT weights to MLX.")
|
||||
parser.add_argument(
|
||||
"--bert-model",
|
||||
choices=[
|
||||
"bert-base-uncased",
|
||||
"bert-base-cased",
|
||||
"bert-large-uncased",
|
||||
"bert-large-cased",
|
||||
],
|
||||
type=str,
|
||||
default="bert-base-uncased",
|
||||
help="The huggingface name of the BERT model to save.",
|
||||
help="The huggingface name of the BERT model to save. Any BERT-like model can be specified.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mlx-model",
|
||||
|
||||
Reference in New Issue
Block a user