diff --git a/phi2/model.py b/phi2/model.py index a99d3d5d..38199c6c 100644 --- a/phi2/model.py +++ b/phi2/model.py @@ -159,11 +159,8 @@ def generate(prompt: mx.array, model: Phi2, temp: Optional[float] = 0.0): def load_model(): model = Phi2(ModelArgs()) - weights = mx.load("weights.npz") - weights = tree_unflatten(list(weights.items())) - model.update(weights) - + model.update(tree_unflatten(list(weights.items()))) tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True) return model, tokenizer @@ -215,3 +212,7 @@ if __name__ == "__main__": s = tokenizer.decode([t.item() for t in tokens]) print(s, end="", flush=True) tokens = [] + + mx.eval(tokens) + s = tokenizer.decode([t.item() for t in tokens]) + print(s, flush=True)