From 152e85fade24265bba3abab771ee671451d83d98 Mon Sep 17 00:00:00 2001 From: Juarez Bochi Date: Sun, 17 Dec 2023 08:51:16 -0500 Subject: [PATCH] Concatenate tokens --- t5/t5.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/t5/t5.py b/t5/t5.py index c165a964..d061e225 100644 --- a/t5/t5.py +++ b/t5/t5.py @@ -321,10 +321,10 @@ def generate( yield y while True: - # logits, cache = model(y[:, None], cache=cache) + # TODO: add cache logits, _ = model(inputs, decoder_inputs) - y = sample(logits.squeeze(1)) - decoder_inputs = mx.concat(decoder_inputs, y, dim=1) + y = mx.expand_dims(sample(logits[:, -1, :]), 0) + decoder_inputs = mx.concatenate([decoder_inputs, y], axis=1) yield y @@ -403,7 +403,7 @@ if __name__ == "__main__": print("[INFO] Generating with T5...", flush=True) print(args.prompt, end="", flush=True) - decoder_inputs = mx.array([[config.decoder_start_token_id]]) + decoder_inputs = mx.array([[config.decoder_start_token_id]]).astype(mx.uint32) tokens = [] for token, _ in zip(