From d74d9453ddcace552828816a4114b0234febf837 Mon Sep 17 00:00:00 2001 From: devonthomas35 <30363743+devonthomas35@users.noreply.github.com> Date: Thu, 14 Dec 2023 21:11:23 -0800 Subject: [PATCH] Refactor EOS check --- phi2/phi2.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/phi2/phi2.py b/phi2/phi2.py index 2d3f792a..4a9ed30e 100644 --- a/phi2/phi2.py +++ b/phi2/phi2.py @@ -202,16 +202,20 @@ if __name__ == "__main__": tokens = [] for token, _ in zip(generate(prompt, model), range(args.max_tokens)): - if token == tokenizer.eos_token_id: - break - else: - tokens.append(token) + tokens.append(token) if (len(tokens) % 10) == 0: mx.eval(tokens) + eos_index = next((i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id), None) + + if eos_index is not None: + tokens = tokens[:eos_index] + s = tokenizer.decode([t.item() for t in tokens]) print(s, end="", flush=True) tokens = [] + if eos_index is not None: + break mx.eval(tokens) s = tokenizer.decode([t.item() for t in tokens])