add max token limit

N8 2024-10-31 02:20:55 -04:00
parent 8fe9539af7
commit 7e4413b1dd
3 changed files with 77 additions and 22 deletions


@@ -29,7 +29,16 @@ def setup_arg_parser():
help="Optional path for the trained adapter weights and config.",
)
parser.add_argument(
"--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
)
parser.add_argument(
"--max-tokens-per-sec",
type=int,
default=None,
help="Maximum tokens to generate per second",
)
parser.add_argument(
"--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
@@ -56,7 +65,7 @@ def main():
tokenizer_config={"trust_remote_code": True},
)
print(f"[INFO] Starting chat sessiong with {args.model}. To exit, enter 'q'.")
print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.")
prompt_cache = make_prompt_cache(model, args.max_kv_size)
while True:
query = input(">> ")
@@ -72,7 +81,9 @@ def main():
prompt,
temp=args.temp,
top_p=args.top_p,
max_tokens_per_sec=args.max_tokens_per_sec,
prompt_cache=prompt_cache,
max_tokens=4096,  # cap the response at a reasonable length
):
print(response, flush=True, end="")
print()
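For reference, a minimal sketch of exercising the new throttle through the streaming API this chat loop wraps. It assumes the patched package is importable as mlx_lm, and the model id is an assumption; with max_tokens_per_sec=10 the inter-token gaps should settle near 0.1 s:

import time

from mlx_lm import load
from mlx_lm.utils import stream_generate

# Assumed model id; any local or mlx-community model should behave the same.
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
prev = time.perf_counter()
for segment in stream_generate(
    model, tokenizer, "hello", max_tokens=20, max_tokens_per_sec=10
):
    now = time.perf_counter()
    print(f"+{now - prev:.3f}s {segment!r}")  # gaps should approach 1/10 s
    prev = now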


@@ -9,8 +9,8 @@ import mlx.core as mx
from .models.cache import load_prompt_cache
from .utils import generate, load
DEFAULT_PROMPT = "hello"
DEFAULT_MAX_TOKENS = 100
DEFAULT_PROMPT = "Tell me a story!"
DEFAULT_MAX_TOKENS = 1000
DEFAULT_TEMP = 0.0
DEFAULT_TOP_P = 1.0
DEFAULT_SEED = 0
@@ -61,6 +61,12 @@ def setup_arg_parser():
default=DEFAULT_MAX_TOKENS,
help="Maximum number of tokens to generate",
)
parser.add_argument(
"--max-tokens-per-sec",
type=int,
default=None,
help="Maximum tokens to generate per second",
)
parser.add_argument(
"--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
)
@@ -227,6 +233,7 @@ def main():
top_p=args.top_p,
max_kv_size=args.max_kv_size,
prompt_cache=prompt_cache if using_cache else None,
max_tokens_per_sec=args.max_tokens_per_sec,
)
if not args.verbose:
print(response)
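The same cap is available programmatically through generate, which this CLI wraps. A hedged sketch using the new defaults shown above (the model id is an assumption):

from mlx_lm import load, generate

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")  # assumed model id
text = generate(
    model,
    tokenizer,
    prompt="Tell me a story!",   # the new DEFAULT_PROMPT
    max_tokens=1000,             # the new DEFAULT_MAX_TOKENS
    max_tokens_per_sec=5,        # pace decoding at roughly 5 tokens/sec
)
print(text)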


@@ -123,11 +123,11 @@ def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float)
logits[:, tokens] = selected_logits
return logits
def generate_step(
prompt: mx.array,
model: nn.Module,
temp: float = 0.0,
max_tokens_per_sec: Optional[float] = None,  # optional cap on decoding speed (tokens/sec)
repetition_penalty: Optional[float] = None,
repetition_context_size: Optional[int] = 20,
top_p: float = 1.0,
@@ -145,9 +145,8 @@ def generate_step(
Args:
prompt (mx.array): The input prompt.
model (nn.Module): The model to use for generation.
temp (float): The temperature for sampling, if 0 the argmax is used.
Default: ``0``.
repetition_penalty (float, optional): The penalty factor for repeating
temp (float): The temperature for sampling, if 0 the argmax is used. Default: ``0``.
repetition_penalty (float, optional): The penalty factor for repeating
tokens.
repetition_context_size (int, optional): The number of tokens to
consider for repetition penalty. Default: ``20``.
@@ -171,7 +170,6 @@ def generate_step(
Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
one token and a vector of log probabilities.
"""
def sample(logits: mx.array) -> Tuple[mx.array, float]:
logprobs = logits - mx.logsumexp(logits)
@@ -193,16 +191,21 @@ def generate_step(
raise ValueError(
f"repetition_penalty must be a non-negative float, got {repetition_penalty}"
)
if max_tokens_per_sec is not None:
if not isinstance(max_tokens_per_sec, (int, float)) or max_tokens_per_sec <= 0:
raise ValueError(
f"max_tokens_per_sec must be a positive number, got {max_tokens_per_sec}"
)
logits_processor = logits_processor or []
if repetition_penalty:
def repetition_penalty_processor(tokens, logits):
return apply_repetition_penalty(
logits, tokens[-repetition_context_size:], repetition_penalty
)
logits_processor.append(repetition_penalty_processor)
if logit_bias:
@@ -247,33 +250,67 @@ generate_step(
y, logprobs = _step(y)
mx.async_eval(y, logprobs)
last_target_time = time.perf_counter()  # absolute deadline of the most recent token
while True:
next_y, next_logprobs = _step(y)
mx.async_eval(next_y, next_logprobs)
if max_tokens_per_sec is not None:
target_time = 1.0 / max_tokens_per_sec
last_target_time += target_time # When we want next token
# Sleep until target time if needed
sleep_time = last_target_time - time.perf_counter()
if sleep_time > 0:
time.sleep(sleep_time)
yield y.item(), logprobs
y, logprobs = next_y, next_logprobs
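The loop above paces output with absolute deadlines rather than a fixed sleep per token: each token's deadline advances by 1/max_tokens_per_sec, and only the remaining slack is slept off, so time spent inside _step is not added on top of the delay. A self-contained sketch of the same scheme:

import time

def paced(items, max_per_sec):
    # Emit items no faster than max_per_sec using absolute deadlines.
    deadline = time.perf_counter()
    for item in items:
        deadline += 1.0 / max_per_sec       # when this item may be emitted
        remaining = deadline - time.perf_counter()
        if remaining > 0:
            time.sleep(remaining)           # sleep off only the leftover slack
        yield item

start = time.perf_counter()
list(paced(range(20), max_per_sec=10))
print(f"elapsed: {time.perf_counter() - start:.2f}s")  # ~2.0 s for 20 items

Because the deadline is absolute, a slow _step call eats into the sleep budget instead of stretching the schedule, which keeps the average rate on target.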
def stream_generate(
model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
prompt: str,
max_tokens: int = 100,
max_tokens_per_sec: Optional[float] = None,  # forwarded to generate_step
**kwargs,
) -> Union[str, Generator[str, None, None]]:
"""
A generator producing text based on the given prompt from the model.
A generator producing text segments based on the given prompt from the model.
Args:
model (nn.Module): The model to use for generation.
tokenizer (PreTrainedTokenizer): The tokenizer.
prompt (str): The string prompt.
max_tokens (int): The maximum number of tokens. Default: ``100``.
kwargs: The remaining options get passed to :func:`generate_step`.
See :func:`generate_step` for more details.
temp (float): The temperature for sampling, if 0 the argmax is used.
Default: ``0``.
repetition_penalty (float, optional): The penalty factor for repeating
tokens.
repetition_context_size (int, optional): The number of tokens to
consider for repetition penalty. Default: ``20``.
top_p (float, optional): Nucleus sampling, higher means the model considers
more less-likely words.
min_p (float, optional): The minimum value (scaled by the top token's
probability) that a token probability must have to be considered.
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
be filtered by min_p sampling.
prefill_step_size (int): Step size for processing the prompt.
max_kv_size (int, optional): Maximum size of the key-value cache. Old
entries (except the first 4 tokens) will be overwritten.
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
provided, the cache will be updated in place.
logit_bias (dictionary, optional): Additive logit bias.
logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional):
A list of functions that take tokens and logits and return the processed
logits. Default: ``None``.
max_tokens_per_sec (float, optional): If set, limits generation speed to approximately
this many tokens per second by adding delays between tokens. Useful for thermal/power
management. Default: ``None`` (no limit).
Yields:
Generator[Tuple[mx.array, mx.array]]: A generator producing text.
Generator[str, None, None]: A generator producing text segments.
"""
if not isinstance(tokenizer, TokenizerWrapper):
tokenizer = TokenizerWrapper(tokenizer)
@@ -283,13 +320,11 @@ def stream_generate(
detokenizer.reset()
for n, (token, _) in zip(
range(max_tokens),
generate_step(prompt_tokens, model, **kwargs),
generate_step(prompt_tokens, model, max_tokens_per_sec=max_tokens_per_sec, **kwargs),
):
if token == tokenizer.eos_token_id:
break
detokenizer.add_token(token)
# Yield the newly decoded text segment
yield detokenizer.last_segment
detokenizer.finalize()
@@ -301,6 +336,7 @@ def generate(
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
prompt: str,
max_tokens: int = 100,
max_tokens_per_sec: Optional[float] = None,  # forwarded to generate_step
verbose: bool = False,
formatter: Optional[Callable] = None,
**kwargs,
@@ -313,6 +349,7 @@ def generate(
tokenizer (PreTrainedTokenizer): The tokenizer.
prompt (str): The string prompt.
max_tokens (int): The maximum number of tokens. Default: ``100``.
max_tokens_per_sec (float, optional): If set, limits generation speed to
approximately this many tokens per second; the rate may briefly exceed
the limit. Default: ``None``.
verbose (bool): If ``True``, print tokens and timing information.
Default: ``False``.
formatter (Optional[Callable]): A function which takes a token and a
@@ -335,7 +372,7 @@ def generate(
for n, (token, logprobs) in zip(
range(max_tokens),
generate_step(prompt_tokens, model, **kwargs),
generate_step(prompt_tokens, model, max_tokens_per_sec=max_tokens_per_sec, **kwargs),
):
if n == 0:
prompt_time = time.perf_counter() - tic
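One consequence worth noting when both limits are set: a completion that runs to max_tokens takes at least max_tokens / max_tokens_per_sec seconds of wall time, e.g.:

max_tokens = 1000         # the new DEFAULT_MAX_TOKENS
max_tokens_per_sec = 5
print(f"wall-time floor: {max_tokens / max_tokens_per_sec:.0f}s")  # 200s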