Refactor EOS check

This commit is contained in:
devonthomas35 2023-12-14 21:11:23 -08:00 committed by GitHub
parent d7d7aabded
commit d74d9453dd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -202,16 +202,20 @@ if __name__ == "__main__":
tokens = []
for token, _ in zip(generate(prompt, model), range(args.max_tokens)):
if token == tokenizer.eos_token_id:
break
else:
tokens.append(token)
tokens.append(token)
if (len(tokens) % 10) == 0:
mx.eval(tokens)
eos_index = next((i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id), None)
if eos_index is not None:
tokens = tokens[:eos_index]
s = tokenizer.decode([t.item() for t in tokens])
print(s, end="", flush=True)
tokens = []
if eos_index is not None:
break
mx.eval(tokens)
s = tokenizer.decode([t.item() for t in tokens])