some cleanup

This commit is contained in:
Awni Hannun
2025-01-09 12:21:31 -08:00
parent 761b2c9886
commit 2797c438bb
13 changed files with 46 additions and 7999 deletions

View File

@@ -1,6 +1,6 @@
import argparse
import numpy
import mlx.core as mx
from transformers import AutoModel
@@ -23,9 +23,9 @@ def convert(bert_model: str, mlx_model: str) -> None:
model = AutoModel.from_pretrained(bert_model)
# save the tensors
tensors = {
replace_key(key): tensor.numpy() for key, tensor in model.state_dict().items()
replace_key(key): mx.array(tensor) for key, tensor in model.state_dict().items()
}
numpy.savez(mlx_model, **tensors)
mx.save_safetensors(mlx_model, tensors)
if __name__ == "__main__":
@@ -39,7 +39,7 @@ if __name__ == "__main__":
parser.add_argument(
"--mlx-model",
type=str,
default="weights/bert-base-uncased.npz",
default="bert-base-uncased.safetensors",
help="The output path for the MLX BERT weights.",
)
args = parser.parse_args()

View File

@@ -136,10 +136,7 @@ def load_model(
def run(bert_model: str, mlx_model: str, batch: List[str]):
model, tokenizer = load_model(bert_model, mlx_model)
tokens = tokenizer(batch, return_tensors="np", padding=True)
tokens = {key: mx.array(v) for key, v in tokens.items()}
tokens = tokenizer(batch, return_tensors="mlx", padding=True)
return model(**tokens)
@@ -149,13 +146,13 @@ if __name__ == "__main__":
"--bert-model",
type=str,
default="bert-base-uncased",
help="The huggingface name of the BERT model to save.",
help="The huggingface name of the BERT model.",
)
parser.add_argument(
"--mlx-model",
type=str,
default="weights/bert-base-uncased.npz",
help="The path of the stored MLX BERT weights (npz file).",
default="bert-base-uncased.safetensors",
help="The path of the stored MLX BERT weights.",
)
parser.add_argument(
"--text",

View File

@@ -1,3 +1,2 @@
mlx>=0.0.5
transformers
numpy

View File

@@ -29,8 +29,8 @@ if __name__ == "__main__":
parser.add_argument(
"--mlx-model",
type=str,
default="weights/bert-base-uncased.npz",
help="The path of the stored MLX BERT weights (npz file).",
default="bert-base-uncased.safetensors",
help="The path of the stored MLX BERT weights.",
)
parser.add_argument(
"--text",

View File

@@ -1 +0,0 @@
*.npz