From 3c63be8c5512620abe2bf1e8b34d5739375afbf3 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Tue, 3 Dec 2024 06:28:21 -0800
Subject: [PATCH] comments

---
 llms/mlx_lm/generate.py |  2 +-
 llms/mlx_lm/utils.py    | 10 +++++-----
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py
index 9e96fbdc..0c1b4acd 100644
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -77,7 +77,7 @@ def setup_arg_parser():
     )
     parser.add_argument(
         "--min-tokens-to-keep",
-        type=float,
+        type=int,
         default=DEFAULT_MIN_TOKENS_TO_KEEP,
         help="Minimum tokens to keep for min-p sampling.",
     )
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 4d815810..86b786ce 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -183,7 +183,7 @@ def generate_step(
     prompt: mx.array,
     model: nn.Module,
     *,
-    max_tokens: int = 100,
+    max_tokens: int = 256,
     sampler: Optional[Callable[mx.array, mx.array]] = None,
     logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
     max_kv_size: Optional[int] = None,
@@ -206,7 +206,8 @@
     Args:
         prompt (mx.array): The input prompt.
         model (nn.Module): The model to use for generation.
-        max_tokens (int): The maximum number of tokens. Default: ``100``.
+        max_tokens (int): The maximum number of tokens. Use ``-1`` for an infinite
+          generator. Default: ``256``.
         sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a
           token from a vector of log probabilities. Default: ``None``.
         logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
@@ -296,7 +297,7 @@
     mx.async_eval(y, logprobs)
     n = 0
     while True:
-        if n < max_tokens:
+        if n != max_tokens:
             next_y, next_logprobs = _step(y)
             mx.async_eval(next_y, next_logprobs)
         if n == 0:
@@ -307,8 +308,7 @@
         yield y.item(), logprobs
         if n % 256 == 0:
             mx.metal.clear_cache()
-        if n < max_tokens:
-            y, logprobs = next_y, next_logprobs
+        y, logprobs = next_y, next_logprobs
         n += 1
 
 
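Note (not part of the patch): a minimal usage sketch of the infinite-generator
behavior this change enables. Replacing ``n < max_tokens`` with
``n != max_tokens`` means that with ``max_tokens=-1`` the counter never matches
the limit, so ``generate_step`` keeps yielding tokens until the caller stops
iterating. The model repo and the EOS stopping rule below are illustrative
assumptions, not taken from this patch:

    # Sketch: consuming generate_step as an infinite generator (max_tokens=-1).
    # The model name and EOS check are illustrative assumptions.
    import mlx.core as mx
    from mlx_lm import load
    from mlx_lm.utils import generate_step

    model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
    prompt = mx.array(tokenizer.encode("Write a haiku about autumn."))

    tokens = []
    # With max_tokens=-1, n != max_tokens stays true forever, so the
    # generator never terminates on its own; the caller decides when to stop.
    for token, _logprobs in generate_step(prompt, model, max_tokens=-1):
        if token == tokenizer.eos_token_id:
            break
        tokens.append(token)

    print(tokenizer.decode(tokens))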