diff --git a/llms/mixtral/mixtral.py b/llms/mixtral/mixtral.py index 8f715180..6aee4031 100644 --- a/llms/mixtral/mixtral.py +++ b/llms/mixtral/mixtral.py @@ -329,9 +329,16 @@ if __name__ == "__main__": if (len(tokens) % 10) == 0: mx.eval(tokens) + eos_index = next( + (i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_id), None + ) + if eos_index is not None: + tokens = tokens[:eos_index] s = tokenizer.decode([t.item() for t in tokens]) print(s, end="", flush=True) tokens = [] + if eos_index is not None: + break mx.eval(tokens) s = tokenizer.decode([t.item() for t in tokens])