From ce9ba916a3072dc9409861c47655d7e4d7ebb2db Mon Sep 17 00:00:00 2001 From: ricardo-larosa Date: Sat, 9 Dec 2023 19:43:44 +0100 Subject: [PATCH] Add arg tokens_per_eval for token generation --- mistral/mistral.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/mistral/mistral.py b/mistral/mistral.py index 6a6447bc..767b5936 100644 --- a/mistral/mistral.py +++ b/mistral/mistral.py @@ -247,6 +247,12 @@ if __name__ == "__main__": type=float, default=1.0, ) + parser.add_argument( + "--tokens_per_eval", + help="The batch size of tokens to generate.", + type=int, + default=10, + ) parser.add_argument("--seed", type=int, default=0, help="The PRNG seed") args = parser.parse_args() @@ -263,7 +269,7 @@ if __name__ == "__main__": for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)): tokens.append(token) - if (len(tokens) % 10) == 0: + if (len(tokens) % args.tokens_per_eval) == 0: mx.eval(tokens) s = tokenizer.decode([t.item() for t in tokens]) print(s, end="", flush=True)