diff --git a/phi2/phi2.py b/phi2/phi2.py index 555ee232..1e57d157 100644 --- a/phi2/phi2.py +++ b/phi2/phi2.py @@ -1,4 +1,5 @@ import argparse +from pathlib import Path from typing import Optional from dataclasses import dataclass from mlx.utils import tree_unflatten @@ -154,9 +155,10 @@ def generate(prompt: mx.array, model: Phi2, temp: Optional[float] = 0.0): yield y -def load_model(): +def load_model(model_path: str): model = Phi2(ModelArgs()) - weights = mx.load("weights.npz") + model_path = Path(model_path) + weights = mx.load(str(model_path / "weights.npz")) model.update(tree_unflatten(list(weights.items()))) tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True) return model, tokenizer @@ -164,6 +166,12 @@ def load_model(): if __name__ == "__main__": parser = argparse.ArgumentParser(description="Phi-2 inference script") + parser.add_argument( + "--model_path", + type=str, + default="phi-2", + help="The path to the model weights", + ) parser.add_argument( "--prompt", help="The message to be processed by the model", @@ -187,7 +195,7 @@ if __name__ == "__main__": mx.random.seed(args.seed) - model, tokenizer = load_model() + model, tokenizer = load_model(args.model_path) prompt = tokenizer( args.prompt,