From 90d3a15ba20cede1a5f459fa623373f591f247fe Mon Sep 17 00:00:00 2001 From: Juarez Bochi Date: Sun, 17 Dec 2023 08:58:09 -0500 Subject: [PATCH] Stop on eos --- t5/t5.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/t5/t5.py b/t5/t5.py index a60b6246..571abbfb 100644 --- a/t5/t5.py +++ b/t5/t5.py @@ -313,10 +313,6 @@ def generate( else: return mx.random.categorical(logits * (1 / temp)) - logits, _ = model(inputs, decoder_inputs) - y = sample(logits[:, -1, :]) - yield y - while True: # TODO: add cache logits, _ = model(inputs, decoder_inputs) @@ -398,21 +394,32 @@ if __name__ == "__main__": exit(0) print("[INFO] Generating with T5...", flush=True) - print(args.prompt, end="", flush=True) + print("Input: ", args.prompt, flush=True) decoder_inputs = mx.array([[config.decoder_start_token_id]]).astype(mx.uint32) tokens = [] for token, _ in zip( - generate(prompt, decoder_inputs, model), range(args.max_tokens) + generate(prompt, decoder_inputs, model, args.temp), + range(args.max_tokens) ): tokens.append(token) if (len(tokens) % 10) == 0: mx.eval(tokens) + eos_index = next( + (i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id), + None, + ) + + if eos_index is not None: + tokens = tokens[:eos_index] + s = tokenizer.decode([t.item() for t in tokens]) print(s, end="", flush=True) tokens = [] + if eos_index is not None: + break mx.eval(tokens) s = tokenizer.decode([t.item() for t in tokens])