mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 01:41:19 +08:00
Add --model_path
to phi-2 example script (#152)
This commit is contained in:
parent
b6e62caf2e
commit
d8e14c858e
14
phi2/phi2.py
14
phi2/phi2.py
@ -1,4 +1,5 @@
|
|||||||
import argparse
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from mlx.utils import tree_unflatten
|
from mlx.utils import tree_unflatten
|
||||||
@ -154,9 +155,10 @@ def generate(prompt: mx.array, model: Phi2, temp: Optional[float] = 0.0):
|
|||||||
yield y
|
yield y
|
||||||
|
|
||||||
|
|
||||||
def load_model():
|
def load_model(model_path: str):
|
||||||
model = Phi2(ModelArgs())
|
model = Phi2(ModelArgs())
|
||||||
weights = mx.load("weights.npz")
|
model_path = Path(model_path)
|
||||||
|
weights = mx.load(str(model_path / "weights.npz"))
|
||||||
model.update(tree_unflatten(list(weights.items())))
|
model.update(tree_unflatten(list(weights.items())))
|
||||||
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
|
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
|
||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
@ -164,6 +166,12 @@ def load_model():
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Phi-2 inference script")
|
parser = argparse.ArgumentParser(description="Phi-2 inference script")
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_path",
|
||||||
|
type=str,
|
||||||
|
default="phi-2",
|
||||||
|
help="The path to the model weights",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--prompt",
|
"--prompt",
|
||||||
help="The message to be processed by the model",
|
help="The message to be processed by the model",
|
||||||
@ -187,7 +195,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
mx.random.seed(args.seed)
|
mx.random.seed(args.seed)
|
||||||
|
|
||||||
model, tokenizer = load_model()
|
model, tokenizer = load_model(args.model_path)
|
||||||
|
|
||||||
prompt = tokenizer(
|
prompt = tokenizer(
|
||||||
args.prompt,
|
args.prompt,
|
||||||
|
Loading…
Reference in New Issue
Block a user