From 840c0c36c29baec53449100883183789310a2ae1 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 14 Dec 2023 08:27:44 -0800 Subject: [PATCH] don't drop last tokens --- phi2/model.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/phi2/model.py b/phi2/model.py index a99d3d5d..38199c6c 100644 --- a/phi2/model.py +++ b/phi2/model.py @@ -159,11 +159,8 @@ def generate(prompt: mx.array, model: Phi2, temp: Optional[float] = 0.0): def load_model(): model = Phi2(ModelArgs()) - weights = mx.load("weights.npz") - weights = tree_unflatten(list(weights.items())) - model.update(weights) - + model.update(tree_unflatten(list(weights.items()))) tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True) return model, tokenizer @@ -215,3 +212,7 @@ if __name__ == "__main__": s = tokenizer.decode([t.item() for t in tokens]) print(s, end="", flush=True) tokens = [] + + mx.eval(tokens) + s = tokenizer.decode([t.item() for t in tokens]) + print(s, flush=True)