Stop on eos

This commit is contained in:
Juarez Bochi 2023-12-17 08:58:09 -05:00
parent 61fda57eba
commit 90d3a15ba2
No known key found for this signature in database
GPG Key ID: 34CCBB77DC8BEBB6

View File

@@ -313,10 +313,6 @@ def generate(
else:
return mx.random.categorical(logits * (1 / temp))
logits, _ = model(inputs, decoder_inputs)
y = sample(logits[:, -1, :])
yield y
while True:
# TODO: add cache
logits, _ = model(inputs, decoder_inputs)
@@ -398,21 +394,32 @@ if __name__ == "__main__":
exit(0)
print("[INFO] Generating with T5...", flush=True)
print(args.prompt, end="", flush=True)
print("Input: ", args.prompt, flush=True)
decoder_inputs = mx.array([[config.decoder_start_token_id]]).astype(mx.uint32)
tokens = []
for token, _ in zip(
generate(prompt, decoder_inputs, model), range(args.max_tokens)
generate(prompt, decoder_inputs, model, args.temp),
range(args.max_tokens)
):
tokens.append(token)
if (len(tokens) % 10) == 0:
mx.eval(tokens)
eos_index = next(
(i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id),
None,
)
if eos_index is not None:
tokens = tokens[:eos_index]
s = tokenizer.decode([t.item() for t in tokens])
print(s, end="", flush=True)
tokens = []
if eos_index is not None:
break
mx.eval(tokens)
s = tokenizer.decode([t.item() for t in tokens])