mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
Stop on eos
This commit is contained in:
parent
61fda57eba
commit
90d3a15ba2
19
t5/t5.py
19
t5/t5.py
@ -313,10 +313,6 @@ def generate(
|
|||||||
else:
|
else:
|
||||||
return mx.random.categorical(logits * (1 / temp))
|
return mx.random.categorical(logits * (1 / temp))
|
||||||
|
|
||||||
logits, _ = model(inputs, decoder_inputs)
|
|
||||||
y = sample(logits[:, -1, :])
|
|
||||||
yield y
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
# TODO: add cache
|
# TODO: add cache
|
||||||
logits, _ = model(inputs, decoder_inputs)
|
logits, _ = model(inputs, decoder_inputs)
|
||||||
@ -398,21 +394,32 @@ if __name__ == "__main__":
|
|||||||
exit(0)
|
exit(0)
|
||||||
|
|
||||||
print("[INFO] Generating with T5...", flush=True)
|
print("[INFO] Generating with T5...", flush=True)
|
||||||
print(args.prompt, end="", flush=True)
|
print("Input: ", args.prompt, flush=True)
|
||||||
|
|
||||||
decoder_inputs = mx.array([[config.decoder_start_token_id]]).astype(mx.uint32)
|
decoder_inputs = mx.array([[config.decoder_start_token_id]]).astype(mx.uint32)
|
||||||
|
|
||||||
tokens = []
|
tokens = []
|
||||||
for token, _ in zip(
|
for token, _ in zip(
|
||||||
generate(prompt, decoder_inputs, model), range(args.max_tokens)
|
generate(prompt, decoder_inputs, model, args.temp),
|
||||||
|
range(args.max_tokens)
|
||||||
):
|
):
|
||||||
tokens.append(token)
|
tokens.append(token)
|
||||||
|
|
||||||
if (len(tokens) % 10) == 0:
|
if (len(tokens) % 10) == 0:
|
||||||
mx.eval(tokens)
|
mx.eval(tokens)
|
||||||
|
eos_index = next(
|
||||||
|
(i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if eos_index is not None:
|
||||||
|
tokens = tokens[:eos_index]
|
||||||
|
|
||||||
s = tokenizer.decode([t.item() for t in tokens])
|
s = tokenizer.decode([t.item() for t in tokens])
|
||||||
print(s, end="", flush=True)
|
print(s, end="", flush=True)
|
||||||
tokens = []
|
tokens = []
|
||||||
|
if eos_index is not None:
|
||||||
|
break
|
||||||
|
|
||||||
mx.eval(tokens)
|
mx.eval(tokens)
|
||||||
s = tokenizer.decode([t.item() for t in tokens])
|
s = tokenizer.decode([t.item() for t in tokens])
|
||||||
|
Loading…
Reference in New Issue
Block a user