diff --git a/t5/t5.py b/t5/t5.py index a60b6246..571abbfb 100644 --- a/t5/t5.py +++ b/t5/t5.py @@ -313,10 +313,6 @@ def generate( else: return mx.random.categorical(logits * (1 / temp)) - logits, _ = model(inputs, decoder_inputs) - y = sample(logits[:, -1, :]) - yield y - while True: # TODO: add cache logits, _ = model(inputs, decoder_inputs) @@ -398,21 +394,32 @@ if __name__ == "__main__": exit(0) print("[INFO] Generating with T5...", flush=True) - print(args.prompt, end="", flush=True) + print("Input: ", args.prompt, flush=True) decoder_inputs = mx.array([[config.decoder_start_token_id]]).astype(mx.uint32) tokens = [] for token, _ in zip( - generate(prompt, decoder_inputs, model), range(args.max_tokens) + generate(prompt, decoder_inputs, model, args.temp), + range(args.max_tokens) ): tokens.append(token) if (len(tokens) % 10) == 0: mx.eval(tokens) + eos_index = next( + (i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id), + None, + ) + + if eos_index is not None: + tokens = tokens[:eos_index] + s = tokenizer.decode([t.item() for t in tokens]) print(s, end="", flush=True) tokens = [] + if eos_index is not None: + break mx.eval(tokens) s = tokenizer.decode([t.item() for t in tokens])