chore(mlx-lm): add max token arg for mlx_lm.chat

This commit is contained in:
Anchen 2024-11-04 07:14:19 +08:00
parent 331148d8ec
commit e0e6847d20

View File

@@ -11,6 +11,7 @@ from .utils import load, stream_generate
 DEFAULT_TEMP = 0.0
 DEFAULT_TOP_P = 1.0
 DEFAULT_SEED = 0
+DEFAULT_MAX_TOKENS = 100
 DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
@@ -41,6 +42,13 @@ def setup_arg_parser():
         help="Set the maximum key-value cache size",
         default=None,
     )
+    parser.add_argument(
+        "--max-tokens",
+        "-m",
+        type=int,
+        default=DEFAULT_MAX_TOKENS,
+        help="Maximum number of tokens to generate",
+    )
     return parser
@@ -70,6 +78,7 @@ def main():
         model,
         tokenizer,
         prompt,
+        args.max_tokens,
         temp=args.temp,
         top_p=args.top_p,
         prompt_cache=prompt_cache,