mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 10:56:38 +08:00
add max token limit
This commit is contained in:
parent
8fe9539af7
commit
7e4413b1dd
@ -29,7 +29,16 @@ def setup_arg_parser():
|
|||||||
help="Optional path for the trained adapter weights and config.",
|
help="Optional path for the trained adapter weights and config.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
|
"--max-tokens-per-sec",
|
||||||
|
type=int,
|
||||||
|
help="Maximum tokens to generate per second.",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-tokens-per-sec",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Maximum tokens to generate per second",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
|
"--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
|
||||||
@ -56,7 +65,7 @@ def main():
|
|||||||
tokenizer_config={"trust_remote_code": True},
|
tokenizer_config={"trust_remote_code": True},
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"[INFO] Starting chat sessiong with {args.model}. To exit, enter 'q'.")
|
print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.")
|
||||||
prompt_cache = make_prompt_cache(model, args.max_kv_size)
|
prompt_cache = make_prompt_cache(model, args.max_kv_size)
|
||||||
while True:
|
while True:
|
||||||
query = input(">> ")
|
query = input(">> ")
|
||||||
@ -72,7 +81,9 @@ def main():
|
|||||||
prompt,
|
prompt,
|
||||||
temp=args.temp,
|
temp=args.temp,
|
||||||
top_p=args.top_p,
|
top_p=args.top_p,
|
||||||
|
max_tokens_per_sec=args.max_tokens_per_sec,
|
||||||
prompt_cache=prompt_cache,
|
prompt_cache=prompt_cache,
|
||||||
|
max_tokens=4096 # Ensure this is set to a reasonable limit
|
||||||
):
|
):
|
||||||
print(response, flush=True, end="")
|
print(response, flush=True, end="")
|
||||||
print()
|
print()
|
||||||
|
@ -9,8 +9,8 @@ import mlx.core as mx
|
|||||||
from .models.cache import load_prompt_cache
|
from .models.cache import load_prompt_cache
|
||||||
from .utils import generate, load
|
from .utils import generate, load
|
||||||
|
|
||||||
DEFAULT_PROMPT = "hello"
|
DEFAULT_PROMPT = "Tell me a story!"
|
||||||
DEFAULT_MAX_TOKENS = 100
|
DEFAULT_MAX_TOKENS = 1000
|
||||||
DEFAULT_TEMP = 0.0
|
DEFAULT_TEMP = 0.0
|
||||||
DEFAULT_TOP_P = 1.0
|
DEFAULT_TOP_P = 1.0
|
||||||
DEFAULT_SEED = 0
|
DEFAULT_SEED = 0
|
||||||
@ -61,6 +61,12 @@ def setup_arg_parser():
|
|||||||
default=DEFAULT_MAX_TOKENS,
|
default=DEFAULT_MAX_TOKENS,
|
||||||
help="Maximum number of tokens to generate",
|
help="Maximum number of tokens to generate",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-tokens-per-sec",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Maximum tokens to generate per second",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
|
"--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
|
||||||
)
|
)
|
||||||
@ -227,6 +233,7 @@ def main():
|
|||||||
top_p=args.top_p,
|
top_p=args.top_p,
|
||||||
max_kv_size=args.max_kv_size,
|
max_kv_size=args.max_kv_size,
|
||||||
prompt_cache=prompt_cache if using_cache else None,
|
prompt_cache=prompt_cache if using_cache else None,
|
||||||
|
max_tokens_per_sec=args.max_tokens_per_sec,
|
||||||
)
|
)
|
||||||
if not args.verbose:
|
if not args.verbose:
|
||||||
print(response)
|
print(response)
|
||||||
|
@ -123,11 +123,11 @@ def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float)
|
|||||||
logits[:, tokens] = selected_logits
|
logits[:, tokens] = selected_logits
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
def generate_step(
|
def generate_step(
|
||||||
prompt: mx.array,
|
prompt: mx.array,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
temp: float = 0.0,
|
temp: float = 0.0,
|
||||||
|
max_tokens_per_sec: Optional[float] = None, # Add new parameter
|
||||||
repetition_penalty: Optional[float] = None,
|
repetition_penalty: Optional[float] = None,
|
||||||
repetition_context_size: Optional[int] = 20,
|
repetition_context_size: Optional[int] = 20,
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
@ -145,8 +145,7 @@ def generate_step(
|
|||||||
Args:
|
Args:
|
||||||
prompt (mx.array): The input prompt.
|
prompt (mx.array): The input prompt.
|
||||||
model (nn.Module): The model to use for generation.
|
model (nn.Module): The model to use for generation.
|
||||||
temp (float): The temperature for sampling, if 0 the argmax is used.
|
temp (float): The temperature for sampling, if 0 the argmax is used. Default: ``0``.
|
||||||
Default: ``0``.
|
|
||||||
repetition_penalty (float, optional): The penalty factor for repeating
|
repetition_penalty (float, optional): The penalty factor for repeating
|
||||||
tokens.
|
tokens.
|
||||||
repetition_context_size (int, optional): The number of tokens to
|
repetition_context_size (int, optional): The number of tokens to
|
||||||
@ -171,7 +170,6 @@ def generate_step(
|
|||||||
Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
|
Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
|
||||||
one token and a vector of log probabilities.
|
one token and a vector of log probabilities.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def sample(logits: mx.array) -> Tuple[mx.array, float]:
|
def sample(logits: mx.array) -> Tuple[mx.array, float]:
|
||||||
logprobs = logits - mx.logsumexp(logits)
|
logprobs = logits - mx.logsumexp(logits)
|
||||||
|
|
||||||
@ -194,15 +192,20 @@ def generate_step(
|
|||||||
f"repetition_penalty must be a non-negative float, got {repetition_penalty}"
|
f"repetition_penalty must be a non-negative float, got {repetition_penalty}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if max_tokens_per_sec is not None:
|
||||||
|
if not isinstance(max_tokens_per_sec, (int, float)) or max_tokens_per_sec <= 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"max_tokens_per_sec must be a positive number, got {max_tokens_per_sec}"
|
||||||
|
)
|
||||||
|
|
||||||
logits_processor = logits_processor or []
|
logits_processor = logits_processor or []
|
||||||
|
last_token_time = time.perf_counter() # Track time for rate limiting
|
||||||
|
|
||||||
if repetition_penalty:
|
if repetition_penalty:
|
||||||
|
|
||||||
def repetition_penalty_processor(tokens, logits):
|
def repetition_penalty_processor(tokens, logits):
|
||||||
return apply_repetition_penalty(
|
return apply_repetition_penalty(
|
||||||
logits, tokens[-repetition_context_size:], repetition_penalty
|
logits, tokens[-repetition_context_size:], repetition_penalty
|
||||||
)
|
)
|
||||||
|
|
||||||
logits_processor.append(repetition_penalty_processor)
|
logits_processor.append(repetition_penalty_processor)
|
||||||
|
|
||||||
if logit_bias:
|
if logit_bias:
|
||||||
@ -247,33 +250,67 @@ def generate_step(
|
|||||||
y, logprobs = _step(y)
|
y, logprobs = _step(y)
|
||||||
|
|
||||||
mx.async_eval(y, logprobs)
|
mx.async_eval(y, logprobs)
|
||||||
|
last_target_time = time.perf_counter() # Track when we WANTED the last token
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
next_y, next_logprobs = _step(y)
|
next_y, next_logprobs = _step(y)
|
||||||
mx.async_eval(next_y, next_logprobs)
|
mx.async_eval(next_y, next_logprobs)
|
||||||
|
|
||||||
|
if max_tokens_per_sec is not None:
|
||||||
|
target_time = 1.0 / max_tokens_per_sec
|
||||||
|
last_target_time += target_time # When we want next token
|
||||||
|
|
||||||
|
# Sleep until target time if needed
|
||||||
|
sleep_time = last_target_time - time.perf_counter()
|
||||||
|
if sleep_time > 0:
|
||||||
|
time.sleep(sleep_time)
|
||||||
|
|
||||||
yield y.item(), logprobs
|
yield y.item(), logprobs
|
||||||
y, logprobs = next_y, next_logprobs
|
y, logprobs = next_y, next_logprobs
|
||||||
|
|
||||||
|
|
||||||
def stream_generate(
|
def stream_generate(
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
|
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
|
||||||
prompt: str,
|
prompt: str,
|
||||||
max_tokens: int = 100,
|
max_tokens: int = 100,
|
||||||
|
max_tokens_per_sec: Optional[float] = None, # Add parameter
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[str, Generator[str, None, None]]:
|
) -> Union[str, Generator[str, None, None]]:
|
||||||
"""
|
"""
|
||||||
A generator producing text based on the given prompt from the model.
|
A generator producing token ids based on the given prompt from the model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prompt (mx.array): The input prompt.
|
prompt (mx.array): The input prompt.
|
||||||
model (nn.Module): The model to use for generation.
|
model (nn.Module): The model to use for generation.
|
||||||
max_tokens (int): The ma
|
temp (float): The temperature for sampling, if 0 the argmax is used.
|
||||||
kwargs: The remaining options get passed to :func:`generate_step`.
|
Default: ``0``.
|
||||||
See :func:`generate_step` for more details.
|
repetition_penalty (float, optional): The penalty factor for repeating
|
||||||
|
tokens.
|
||||||
|
repetition_context_size (int, optional): The number of tokens to
|
||||||
|
consider for repetition penalty. Default: ``20``.
|
||||||
|
top_p (float, optional): Nulceus sampling, higher means model considers
|
||||||
|
more less likely words.
|
||||||
|
min_p (float, optional): The minimum value (scaled by the top token's
|
||||||
|
probability) that a token probability must have to be considered.
|
||||||
|
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
|
||||||
|
be filtered by min_p sampling.
|
||||||
|
prefill_step_size (int): Step size for processing the prompt.
|
||||||
|
max_kv_size (int, optional): Maximum size of the key-value cache. Old
|
||||||
|
entries (except the first 4 tokens) will be overwritten.
|
||||||
|
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
|
||||||
|
provided, the cache will be updated in place.
|
||||||
|
logit_bias (dictionary, optional): Additive logit bias.
|
||||||
|
logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional):
|
||||||
|
A list of functions that take tokens and logits and return the processed
|
||||||
|
logits. Default: ``None``.
|
||||||
|
max_tokens_per_sec (float, optional): If set, limits generation speed to approximately
|
||||||
|
this many tokens per second by adding delays between tokens. Useful for thermal/power
|
||||||
|
management. Default: None (no limit).
|
||||||
Yields:
|
Yields:
|
||||||
Generator[Tuple[mx.array, mx.array]]: A generator producing text.
|
Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
|
||||||
|
one token and a vector of log probabilities.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not isinstance(tokenizer, TokenizerWrapper):
|
if not isinstance(tokenizer, TokenizerWrapper):
|
||||||
tokenizer = TokenizerWrapper(tokenizer)
|
tokenizer = TokenizerWrapper(tokenizer)
|
||||||
|
|
||||||
@ -283,13 +320,11 @@ def stream_generate(
|
|||||||
detokenizer.reset()
|
detokenizer.reset()
|
||||||
for n, (token, _) in zip(
|
for n, (token, _) in zip(
|
||||||
range(max_tokens),
|
range(max_tokens),
|
||||||
generate_step(prompt_tokens, model, **kwargs),
|
generate_step(prompt_tokens, model, max_tokens_per_sec=max_tokens_per_sec, **kwargs),
|
||||||
):
|
):
|
||||||
if token == tokenizer.eos_token_id:
|
if token == tokenizer.eos_token_id:
|
||||||
break
|
break
|
||||||
detokenizer.add_token(token)
|
detokenizer.add_token(token)
|
||||||
|
|
||||||
# Yield the last segment if streaming
|
|
||||||
yield detokenizer.last_segment
|
yield detokenizer.last_segment
|
||||||
|
|
||||||
detokenizer.finalize()
|
detokenizer.finalize()
|
||||||
@ -301,6 +336,7 @@ def generate(
|
|||||||
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
|
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
|
||||||
prompt: str,
|
prompt: str,
|
||||||
max_tokens: int = 100,
|
max_tokens: int = 100,
|
||||||
|
max_tokens_per_sec: Optional[float] = None, # Add parameter
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
formatter: Optional[Callable] = None,
|
formatter: Optional[Callable] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@ -313,6 +349,7 @@ def generate(
|
|||||||
tokenizer (PreTrainedTokenizer): The tokenizer.
|
tokenizer (PreTrainedTokenizer): The tokenizer.
|
||||||
prompt (str): The string prompt.
|
prompt (str): The string prompt.
|
||||||
max_tokens (int): The maximum number of tokens. Default: ``100``.
|
max_tokens (int): The maximum number of tokens. Default: ``100``.
|
||||||
|
max_tokens_per_sec (float, optional): If set, limits generation speed to approximately max_tokens_per_sec. May go slightly over this limit.
|
||||||
verbose (bool): If ``True``, print tokens and timing information.
|
verbose (bool): If ``True``, print tokens and timing information.
|
||||||
Default: ``False``.
|
Default: ``False``.
|
||||||
formatter (Optional[Callable]): A function which takes a token and a
|
formatter (Optional[Callable]): A function which takes a token and a
|
||||||
@ -335,7 +372,7 @@ def generate(
|
|||||||
|
|
||||||
for n, (token, logprobs) in zip(
|
for n, (token, logprobs) in zip(
|
||||||
range(max_tokens),
|
range(max_tokens),
|
||||||
generate_step(prompt_tokens, model, **kwargs),
|
generate_step(prompt_tokens, model, max_tokens_per_sec=max_tokens_per_sec, **kwargs),
|
||||||
):
|
):
|
||||||
if n == 0:
|
if n == 0:
|
||||||
prompt_time = time.perf_counter() - tic
|
prompt_time = time.perf_counter() - tic
|
||||||
|
Loading…
Reference in New Issue
Block a user