diff --git a/phi2/phi2.py b/phi2/phi2.py index 7973c33d..4a9ed30e 100644 --- a/phi2/phi2.py +++ b/phi2/phi2.py @@ -206,9 +206,16 @@ if __name__ == "__main__": if (len(tokens) % 10) == 0: mx.eval(tokens) + eos_index = next((i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id), None) + + if eos_index is not None: + tokens = tokens[:eos_index] + s = tokenizer.decode([t.item() for t in tokens]) print(s, end="", flush=True) tokens = [] + if eos_index is not None: + break mx.eval(tokens) s = tokenizer.decode([t.item() for t in tokens])