mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Add arg tokens_per_eval for token generation
This commit is contained in:
parent
0bf5d0e3bc
commit
ce9ba916a3
@ -247,6 +247,12 @@ if __name__ == "__main__":
|
|||||||
type=float,
|
type=float,
|
||||||
default=1.0,
|
default=1.0,
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--tokens_per_eval",
|
||||||
|
help="The batch size of tokens to generate.",
|
||||||
|
type=int,
|
||||||
|
default=10,
|
||||||
|
)
|
||||||
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
|
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@ -263,7 +269,7 @@ if __name__ == "__main__":
|
|||||||
for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
|
for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
|
||||||
tokens.append(token)
|
tokens.append(token)
|
||||||
|
|
||||||
if (len(tokens) % 10) == 0:
|
if (len(tokens) % args.tokens_per_eval) == 0:
|
||||||
mx.eval(tokens)
|
mx.eval(tokens)
|
||||||
s = tokenizer.decode([t.item() for t in tokens])
|
s = tokenizer.decode([t.item() for t in tokens])
|
||||||
print(s, end="", flush=True)
|
print(s, end="", flush=True)
|
||||||
|
Loading…
Reference in New Issue
Block a user