From e6d35301bde6f0a046a3e5ad38c943d797b0c6e1 Mon Sep 17 00:00:00 2001 From: N8 Date: Thu, 31 Oct 2024 02:37:14 -0400 Subject: [PATCH] smol modification --- llms/mlx_lm/utils.py | 36 ++++++++---------------------------- 1 file changed, 8 insertions(+), 28 deletions(-) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 720151e9..c26cdeb0 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -146,7 +146,7 @@ def generate_step( prompt (mx.array): The input prompt. model (nn.Module): The model to use for generation. temp (float): The temperature for sampling, if 0 the argmax is used. Default: ``0``. - repetition_penalty (float, optional): The penalty factor for repeating + repetition_penalty (float, optional): The penalty factor for repeating tokens. repetition_context_size (int, optional): The number of tokens to consider for repetition penalty. Default: ``20``. @@ -277,38 +277,18 @@ def stream_generate( **kwargs, ) -> Union[str, Generator[str, None, None]]: """ - A generator producing token ids based on the given prompt from the model. + A generator producing text based on the given prompt from the model. Args: prompt (mx.array): The input prompt. model (nn.Module): The model to use for generation. - temp (float): The temperature for sampling, if 0 the argmax is used. - Default: ``0``. - repetition_penalty (float, optional): The penalty factor for repeating - tokens. - repetition_context_size (int, optional): The number of tokens to - consider for repetition penalty. Default: ``20``. - top_p (float, optional): Nulceus sampling, higher means model considers - more less likely words. - min_p (float, optional): The minimum value (scaled by the top token's - probability) that a token probability must have to be considered. - min_tokens_to_keep (int, optional): Minimum number of tokens that cannot - be filtered by min_p sampling. - prefill_step_size (int): Step size for processing the prompt. 
- max_kv_size (int, optional): Maximum size of the key-value cache. Old - entries (except the first 4 tokens) will be overwritten. - prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if - provided, the cache will be updated in place. - logit_bias (dictionary, optional): Additive logit bias. - logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional): - A list of functions that take tokens and logits and return the processed - logits. Default: ``None``. - max_tokens_per_sec (float, optional): If set, limits generation speed to approximately - this many tokens per second by adding delays between tokens. Useful for thermal/power - management. Default: None (no limit). + max_tokens (int): The maximum number of tokens. Default: ``100``. + max_tokens_per_sec (float, optional): If set, limits generation speed to approximately max_tokens_per_sec. May go slightly over this limit. + kwargs: The remaining options get passed to :func:`generate_step`. + See :func:`generate_step` for more details. + Yields: - Generator[Tuple[mx.array, mx.array], None, None]: A generator producing - one token and a vector of log probabilities. + Generator[str, None, None]: A generator producing text. """ if not isinstance(tokenizer, TokenizerWrapper):