diff --git a/llama/llama.py b/llama/llama.py index af98685d..ad6fd8ce 100644 --- a/llama/llama.py +++ b/llama/llama.py @@ -294,7 +294,7 @@ def few_shot_generate(args): mx.eval(token) prompt_processing = toc("Prompt processing", start) - if len(tokens) >= args.num_tokens: + if len(tokens) >= args.max_tokens: break mx.eval(tokens)