From 4549dcbbd03f48c99b14fccd82908c76c48adcd8 Mon Sep 17 00:00:00 2001
From: devonthomas35 <30363743+devonthomas35@users.noreply.github.com>
Date: Thu, 14 Dec 2023 15:50:59 -0800
Subject: [PATCH 1/3] Stop generating at eos token

---
 phi2/phi2.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/phi2/phi2.py b/phi2/phi2.py
index 7973c33d..ede79ea2 100644
--- a/phi2/phi2.py
+++ b/phi2/phi2.py
@@ -202,7 +202,11 @@ if __name__ == "__main__":
 
     tokens = []
     for token, _ in zip(generate(prompt, model), range(args.max_tokens)):
-        tokens.append(token)
+
+        if token == tokenizer.eos_token_id:
+            break
+        else:
+            tokens.append(token)
 
         if (len(tokens) % 10) == 0:
             mx.eval(tokens)

From d7d7aabded3b52d36c5f3a3675553d5651639b6f Mon Sep 17 00:00:00 2001
From: devonthomas35 <30363743+devonthomas35@users.noreply.github.com>
Date: Thu, 14 Dec 2023 15:52:22 -0800
Subject: [PATCH 2/3] Remove unnecessary return

---
 phi2/phi2.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/phi2/phi2.py b/phi2/phi2.py
index ede79ea2..2d3f792a 100644
--- a/phi2/phi2.py
+++ b/phi2/phi2.py
@@ -202,7 +202,6 @@ if __name__ == "__main__":
 
     tokens = []
     for token, _ in zip(generate(prompt, model), range(args.max_tokens)):
-
         if token == tokenizer.eos_token_id:
             break
         else:

From d74d9453ddcace552828816a4114b0234febf837 Mon Sep 17 00:00:00 2001
From: devonthomas35 <30363743+devonthomas35@users.noreply.github.com>
Date: Thu, 14 Dec 2023 21:11:23 -0800
Subject: [PATCH 3/3] Refactor EOS check

---
 phi2/phi2.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/phi2/phi2.py b/phi2/phi2.py
index 2d3f792a..4a9ed30e 100644
--- a/phi2/phi2.py
+++ b/phi2/phi2.py
@@ -202,16 +202,20 @@ if __name__ == "__main__":
 
     tokens = []
     for token, _ in zip(generate(prompt, model), range(args.max_tokens)):
-        if token == tokenizer.eos_token_id:
-            break
-        else:
-            tokens.append(token)
+        tokens.append(token)
 
         if (len(tokens) % 10) == 0:
             mx.eval(tokens)
+            eos_index = next((i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id), None)
+
+            if eos_index is not None:
+                tokens = tokens[:eos_index]
+
             s = tokenizer.decode([t.item() for t in tokens])
             print(s, end="", flush=True)
             tokens = []
+            if eos_index is not None:
+                break
 
     mx.eval(tokens)
     s = tokenizer.decode([t.item() for t in tokens])
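
Net result of the series, for review convenience: after all three patches the loop appends every sampled token and defers the EOS check to the every-10-tokens flush, where the batch is evaluated, truncated at the first EOS if one appears, printed, and the loop exited. A sketch of the resulting generation loop, assembled from the hunk context above (the long eos_index line is reflowed, and the comments are annotations rather than part of the committed file):

    tokens = []
    for token, _ in zip(generate(prompt, model), range(args.max_tokens)):
        tokens.append(token)

        if (len(tokens) % 10) == 0:
            # Evaluate the pending batch, then scan it for the first EOS.
            mx.eval(tokens)
            eos_index = next(
                (i for i, t in enumerate(tokens)
                 if t.item() == tokenizer.eos_token_id),
                None,
            )

            if eos_index is not None:
                # Drop the EOS token and anything after it before printing.
                tokens = tokens[:eos_index]

            s = tokenizer.decode([t.item() for t in tokens])
            print(s, end="", flush=True)
            tokens = []
            if eos_index is not None:
                break

    # Flush whatever remains (fewer than 10 tokens) after the loop ends.
    mx.eval(tokens)
    s = tokenizer.decode([t.item() for t in tokens])

This likely motivates the PATCH 3/3 refactor: the per-token comparison from PATCH 1/3 forces each token's evaluation individually inside the loop, whereas this version only inspects token values after the batched mx.eval, at the cost of scanning the batch once per flush.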