smol modification

This commit is contained in:
N8 2024-10-31 02:37:14 -04:00
parent 7e4413b1dd
commit e6d35301bd

View File

@ -146,7 +146,7 @@ def generate_step(
prompt (mx.array): The input prompt.
model (nn.Module): The model to use for generation.
temp (float): The temperature for sampling, if 0 the argmax is used. Default: ``0``.
repetition_penalty (float, optional): The penalty factor for repeating
repetition_penalty (float, optional): The penalty factor for repeating
tokens.
repetition_context_size (int, optional): The number of tokens to
consider for repetition penalty. Default: ``20``.
@ -277,38 +277,18 @@ def stream_generate(
**kwargs,
) -> Union[str, Generator[str, None, None]]:
"""
A generator producing token ids based on the given prompt from the model.
A generator producing text based on the given prompt from the model.
Args:
prompt (mx.array): The input prompt.
model (nn.Module): The model to use for generation.
temp (float): The temperature for sampling, if 0 the argmax is used.
Default: ``0``.
repetition_penalty (float, optional): The penalty factor for repeating
tokens.
repetition_context_size (int, optional): The number of tokens to
consider for repetition penalty. Default: ``20``.
top_p (float, optional): Nulceus sampling, higher means model considers
more less likely words.
min_p (float, optional): The minimum value (scaled by the top token's
probability) that a token probability must have to be considered.
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
be filtered by min_p sampling.
prefill_step_size (int): Step size for processing the prompt.
max_kv_size (int, optional): Maximum size of the key-value cache. Old
entries (except the first 4 tokens) will be overwritten.
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
provided, the cache will be updated in place.
logit_bias (dictionary, optional): Additive logit bias.
logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional):
A list of functions that take tokens and logits and return the processed
logits. Default: ``None``.
max_tokens_per_sec (float, optional): If set, limits generation speed to approximately
this many tokens per second by adding delays between tokens. Useful for thermal/power
management. Default: None (no limit).
max_tokens (int): The maximum number of tokens. Default: ``100``.
max_tokens_per_sec (float, optional): If set, limits generation speed to approximately max_tokens_per_sec. May go slightly over this limit.
kwargs: The remaining options get passed to :func:`generate_step`.
See :func:`generate_step` for more details.
Yields:
Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
one token and a vector of log probabilities.
Generator[Tuple[mx.array, mx.array]]: A generator producing text.
"""
if not isinstance(tokenizer, TokenizerWrapper):