mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:26:37 +08:00
Concatenate tokens
This commit is contained in:
parent
daea1dcddf
commit
152e85fade
8
t5/t5.py
8
t5/t5.py
@@ -321,10 +321,10 @@ def generate(
|
||||
yield y
|
||||
|
||||
while True:
|
||||
# logits, cache = model(y[:, None], cache=cache)
|
||||
# TODO: add cache
|
||||
logits, _ = model(inputs, decoder_inputs)
|
||||
y = sample(logits.squeeze(1))
|
||||
decoder_inputs = mx.concat(decoder_inputs, y, dim=1)
|
||||
y = mx.expand_dims(sample(logits[:, -1, :]), 0)
|
||||
decoder_inputs = mx.concatenate([decoder_inputs, y], axis=1)
|
||||
yield y
|
||||
|
||||
|
||||
@@ -403,7 +403,7 @@ if __name__ == "__main__":
|
||||
print("[INFO] Generating with T5...", flush=True)
|
||||
print(args.prompt, end="", flush=True)
|
||||
|
||||
decoder_inputs = mx.array([[config.decoder_start_token_id]])
|
||||
decoder_inputs = mx.array([[config.decoder_start_token_id]]).astype(mx.uint32)
|
||||
|
||||
tokens = []
|
||||
for token, _ in zip(
|
||||
|
Loading…
Reference in New Issue
Block a user