mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 01:41:19 +08:00
Add --model_path
to phi-2 example script (#152)
This commit is contained in:
parent
b6e62caf2e
commit
d8e14c858e
14
phi2/phi2.py
14
phi2/phi2.py
@ -1,4 +1,5 @@
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
from mlx.utils import tree_unflatten
|
||||
@ -154,9 +155,10 @@ def generate(prompt: mx.array, model: Phi2, temp: Optional[float] = 0.0):
|
||||
yield y
|
||||
|
||||
|
||||
def load_model():
|
||||
def load_model(model_path: str):
|
||||
model = Phi2(ModelArgs())
|
||||
weights = mx.load("weights.npz")
|
||||
model_path = Path(model_path)
|
||||
weights = mx.load(str(model_path / "weights.npz"))
|
||||
model.update(tree_unflatten(list(weights.items())))
|
||||
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
|
||||
return model, tokenizer
|
||||
@ -164,6 +166,12 @@ def load_model():
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Phi-2 inference script")
|
||||
parser.add_argument(
|
||||
"--model_path",
|
||||
type=str,
|
||||
default="phi-2",
|
||||
help="The path to the model weights",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
help="The message to be processed by the model",
|
||||
@ -187,7 +195,7 @@ if __name__ == "__main__":
|
||||
|
||||
mx.random.seed(args.seed)
|
||||
|
||||
model, tokenizer = load_model()
|
||||
model, tokenizer = load_model(args.model_path)
|
||||
|
||||
prompt = tokenizer(
|
||||
args.prompt,
|
||||
|
Loading…
Reference in New Issue
Block a user