Add --model_path to phi-2 example script (#152)

This commit is contained in:
Pedro Cuenca 2023-12-20 15:14:35 +01:00 committed by GitHub
parent b6e62caf2e
commit d8e14c858e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,4 +1,5 @@
import argparse import argparse
from pathlib import Path
from typing import Optional from typing import Optional
from dataclasses import dataclass from dataclasses import dataclass
from mlx.utils import tree_unflatten from mlx.utils import tree_unflatten
@ -154,9 +155,10 @@ def generate(prompt: mx.array, model: Phi2, temp: Optional[float] = 0.0):
yield y yield y
def load_model(): def load_model(model_path: str):
model = Phi2(ModelArgs()) model = Phi2(ModelArgs())
weights = mx.load("weights.npz") model_path = Path(model_path)
weights = mx.load(str(model_path / "weights.npz"))
model.update(tree_unflatten(list(weights.items()))) model.update(tree_unflatten(list(weights.items())))
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
return model, tokenizer return model, tokenizer
@ -164,6 +166,12 @@ def load_model():
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Phi-2 inference script") parser = argparse.ArgumentParser(description="Phi-2 inference script")
parser.add_argument(
"--model_path",
type=str,
default="phi-2",
help="The path to the model weights",
)
parser.add_argument( parser.add_argument(
"--prompt", "--prompt",
help="The message to be processed by the model", help="The message to be processed by the model",
@ -187,7 +195,7 @@ if __name__ == "__main__":
mx.random.seed(args.seed) mx.random.seed(args.seed)
model, tokenizer = load_model() model, tokenizer = load_model(args.model_path)
prompt = tokenizer( prompt = tokenizer(
args.prompt, args.prompt,