mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00
some cleanup
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import argparse
|
||||
|
||||
import numpy
|
||||
import mlx.core as mx
|
||||
from transformers import AutoModel
|
||||
|
||||
|
||||
@@ -23,9 +23,9 @@ def convert(bert_model: str, mlx_model: str) -> None:
|
||||
model = AutoModel.from_pretrained(bert_model)
|
||||
# save the tensors
|
||||
tensors = {
|
||||
replace_key(key): tensor.numpy() for key, tensor in model.state_dict().items()
|
||||
replace_key(key): mx.array(tensor) for key, tensor in model.state_dict().items()
|
||||
}
|
||||
numpy.savez(mlx_model, **tensors)
|
||||
mx.save_safetensors(mlx_model, tensors)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -39,7 +39,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--mlx-model",
|
||||
type=str,
|
||||
default="weights/bert-base-uncased.npz",
|
||||
default="bert-base-uncased.safetensors",
|
||||
help="The output path for the MLX BERT weights.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -136,10 +136,7 @@ def load_model(
|
||||
|
||||
def run(bert_model: str, mlx_model: str, batch: List[str]):
|
||||
model, tokenizer = load_model(bert_model, mlx_model)
|
||||
|
||||
tokens = tokenizer(batch, return_tensors="np", padding=True)
|
||||
tokens = {key: mx.array(v) for key, v in tokens.items()}
|
||||
|
||||
tokens = tokenizer(batch, return_tensors="mlx", padding=True)
|
||||
return model(**tokens)
|
||||
|
||||
|
||||
@@ -149,13 +146,13 @@ if __name__ == "__main__":
|
||||
"--bert-model",
|
||||
type=str,
|
||||
default="bert-base-uncased",
|
||||
help="The huggingface name of the BERT model to save.",
|
||||
help="The huggingface name of the BERT model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mlx-model",
|
||||
type=str,
|
||||
default="weights/bert-base-uncased.npz",
|
||||
help="The path of the stored MLX BERT weights (npz file).",
|
||||
default="bert-base-uncased.safetensors",
|
||||
help="The path of the stored MLX BERT weights.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--text",
|
||||
|
||||
@@ -1,3 +1,2 @@
|
||||
mlx>=0.0.5
|
||||
transformers
|
||||
numpy
|
||||
|
||||
@@ -29,8 +29,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--mlx-model",
|
||||
type=str,
|
||||
default="weights/bert-base-uncased.npz",
|
||||
help="The path of the stored MLX BERT weights (npz file).",
|
||||
default="bert-base-uncased.safetensors",
|
||||
help="The path of the stored MLX BERT weights.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--text",
|
||||
|
||||
1
bert/weights/.gitignore
vendored
1
bert/weights/.gitignore
vendored
@@ -1 +0,0 @@
|
||||
*.npz
|
||||
Reference in New Issue
Block a user