From 939086e6a3e92770ae9a024f8874a8c79f31d560 Mon Sep 17 00:00:00 2001 From: devonthomas35 <30363743+devonthomas35@users.noreply.github.com> Date: Sat, 23 Dec 2023 21:25:42 -0800 Subject: [PATCH] Mixtral: Stop at EOS token (#183) * Stop at EOS token * Precommit format files * Fix precommit hooks * Fix precommit hooks --- llms/mixtral/mixtral.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/llms/mixtral/mixtral.py b/llms/mixtral/mixtral.py index 8f715180..6aee4031 100644 --- a/llms/mixtral/mixtral.py +++ b/llms/mixtral/mixtral.py @@ -329,9 +329,16 @@ if __name__ == "__main__": if (len(tokens) % 10) == 0: mx.eval(tokens) + eos_index = next( + (i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_id), None + ) + if eos_index is not None: + tokens = tokens[:eos_index] s = tokenizer.decode([t.item() for t in tokens]) print(s, end="", flush=True) tokens = [] + if eos_index is not None: + break mx.eval(tokens) s = tokenizer.decode([t.item() for t in tokens])