Stop generating at eos token

This commit is contained in:
devonthomas35 2023-12-14 15:50:59 -08:00 committed by GitHub
parent 67a208b13e
commit 4549dcbbd0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -202,7 +202,11 @@ if __name__ == "__main__":
tokens = []
for token, _ in zip(generate(prompt, model), range(args.max_tokens)):
tokens.append(token)
if token == tokenizer.eos_token_id:
break
else:
tokens.append(token)
if (len(tokens) % 10) == 0:
mx.eval(tokens)