Concatenate tokens

This commit is contained in:
Juarez Bochi 2023-12-17 08:51:16 -05:00
parent daea1dcddf
commit 152e85fade
No known key found for this signature in database
GPG Key ID: 34CCBB77DC8BEBB6

View File

@ -321,10 +321,10 @@ def generate(
yield y
while True:
# logits, cache = model(y[:, None], cache=cache)
# TODO: add cache
logits, _ = model(inputs, decoder_inputs)
y = sample(logits.squeeze(1))
decoder_inputs = mx.concat(decoder_inputs, y, dim=1)
y = mx.expand_dims(sample(logits[:, -1, :]), 0)
decoder_inputs = mx.concatenate([decoder_inputs, y], axis=1)
yield y
@ -403,7 +403,7 @@ if __name__ == "__main__":
print("[INFO] Generating with T5...", flush=True)
print(args.prompt, end="", flush=True)
decoder_inputs = mx.array([[config.decoder_start_token_id]])
decoder_inputs = mx.array([[config.decoder_start_token_id]]).astype(mx.uint32)
tokens = []
for token, _ in zip(