mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Merge pull request #52 from ricardo-larosa/mistral_batch_size
Mistral: Pass argument --tokens_per_eval for token generation
This commit is contained in:
commit
2652b4f055
@ -247,6 +247,12 @@ if __name__ == "__main__":
|
||||
type=float,
|
||||
default=1.0,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokens_per_eval",
|
||||
help="The batch size of tokens to generate.",
|
||||
type=int,
|
||||
default=10,
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
|
||||
|
||||
args = parser.parse_args()
|
||||
@ -263,7 +269,7 @@ if __name__ == "__main__":
|
||||
for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
|
||||
tokens.append(token)
|
||||
|
||||
if (len(tokens) % 10) == 0:
|
||||
if (len(tokens) % args.tokens_per_eval) == 0:
|
||||
mx.eval(tokens)
|
||||
s = tokenizer.decode([t.item() for t in tokens])
|
||||
print(s, end="", flush=True)
|
||||
|
Loading…
Reference in New Issue
Block a user