mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
Stop on eos
This commit is contained in:
parent
61fda57eba
commit
90d3a15ba2
19
t5/t5.py
19
t5/t5.py
@ -313,10 +313,6 @@ def generate(
|
||||
else:
|
||||
return mx.random.categorical(logits * (1 / temp))
|
||||
|
||||
logits, _ = model(inputs, decoder_inputs)
|
||||
y = sample(logits[:, -1, :])
|
||||
yield y
|
||||
|
||||
while True:
|
||||
# TODO: add cache
|
||||
logits, _ = model(inputs, decoder_inputs)
|
||||
@ -398,21 +394,32 @@ if __name__ == "__main__":
|
||||
exit(0)
|
||||
|
||||
print("[INFO] Generating with T5...", flush=True)
|
||||
print(args.prompt, end="", flush=True)
|
||||
print("Input: ", args.prompt, flush=True)
|
||||
|
||||
decoder_inputs = mx.array([[config.decoder_start_token_id]]).astype(mx.uint32)
|
||||
|
||||
tokens = []
|
||||
for token, _ in zip(
|
||||
generate(prompt, decoder_inputs, model), range(args.max_tokens)
|
||||
generate(prompt, decoder_inputs, model, args.temp),
|
||||
range(args.max_tokens)
|
||||
):
|
||||
tokens.append(token)
|
||||
|
||||
if (len(tokens) % 10) == 0:
|
||||
mx.eval(tokens)
|
||||
eos_index = next(
|
||||
(i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id),
|
||||
None,
|
||||
)
|
||||
|
||||
if eos_index is not None:
|
||||
tokens = tokens[:eos_index]
|
||||
|
||||
s = tokenizer.decode([t.item() for t in tokens])
|
||||
print(s, end="", flush=True)
|
||||
tokens = []
|
||||
if eos_index is not None:
|
||||
break
|
||||
|
||||
mx.eval(tokens)
|
||||
s = tokenizer.decode([t.item() for t in tokens])
|
||||
|
Loading…
Reference in New Issue
Block a user