diff --git a/llms/mlx_lm/_version.py b/llms/mlx_lm/_version.py
index 343e0016..0f885fba 100644
--- a/llms/mlx_lm/_version.py
+++ b/llms/mlx_lm/_version.py
@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.
 
-__version__ = "0.20.1"
+__version__ = "0.20.2"
diff --git a/llms/mlx_lm/cache_prompt.py b/llms/mlx_lm/cache_prompt.py
index 987b640d..9d7d1603 100644
--- a/llms/mlx_lm/cache_prompt.py
+++ b/llms/mlx_lm/cache_prompt.py
@@ -8,7 +8,7 @@ import time
 import mlx.core as mx
 
 from .models.cache import make_prompt_cache, save_prompt_cache
-from .utils import load, maybe_quantize_kv_cache
+from .utils import generate_step, load
 
 DEFAULT_QUANTIZED_KV_START = 5000
 
@@ -50,12 +50,6 @@ def setup_arg_parser():
         action="store_true",
         help="Use the default chat template",
     )
-    parser.add_argument(
-        "--cache-limit-gb",
-        type=int,
-        default=None,
-        help="Set the MLX cache limit in GB",
-    )
     parser.add_argument(
         "--max-kv-size",
         type=int,
@@ -99,9 +93,6 @@ def main():
     parser = setup_arg_parser()
     args = parser.parse_args()
 
-    if args.cache_limit_gb is not None:
-        mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
-
     # Building tokenizer_config
     tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
     if args.eos_token is not None:
@@ -144,26 +135,28 @@ def main():
     y = mx.array(tokenizer.encode(prompt))
 
     # Process the prompt
-    processed = 0
-    step_size = 512
     start = time.time()
     max_msg_len = 0
-    while y.size > 0:
-        model(y[:step_size][None], cache=cache)
-        mx.eval([c.state for c in cache])
-        mx.metal.clear_cache()
-        processed += min(y.size, step_size)
-        y = y[step_size:]
+
+    def callback(processed, total_tokens):
         current = time.time()
         speed = processed / (current - start)
         msg = f"\rProcessed {processed:6d} tokens ({speed:6.2f} tok/s)"
+        nonlocal max_msg_len
         max_msg_len = max(max_msg_len, len(msg))
         print(msg + " " * (max_msg_len - len(msg)), end="", flush=True)
-
-        maybe_quantize_kv_cache(
-            cache, args.quantized_kv_start, args.kv_group_size, args.kv_bits
-        )
+
+    for _ in generate_step(
+        y,
+        model,
+        max_tokens=0,
+        prompt_cache=cache,
+        kv_bits=args.kv_bits,
+        kv_group_size=args.kv_group_size,
+        quantized_kv_start=args.quantized_kv_start,
+        prompt_progress_callback=callback,
+    ):
+        pass
 
     print()
     print(f"Peak memory: {mx.metal.get_peak_memory() / 1e9:.3f} GB")
diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py
index 9e96fbdc..0c1b4acd 100644
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -77,7 +77,7 @@ def setup_arg_parser():
     )
     parser.add_argument(
         "--min-tokens-to-keep",
-        type=float,
+        type=int,
         default=DEFAULT_MIN_TOKENS_TO_KEEP,
         help="Minimum tokens to keep for min-p sampling.",
     )
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index f439ca99..86b786ce 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -183,6 +183,7 @@ def generate_step(
     prompt: mx.array,
     model: nn.Module,
     *,
+    max_tokens: int = 256,
     sampler: Optional[Callable[mx.array, mx.array]] = None,
     logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
     max_kv_size: Optional[int] = None,
@@ -191,6 +192,7 @@ def generate_step(
     kv_bits: Optional[int] = None,
     kv_group_size: int = 64,
     quantized_kv_start: int = 0,
+    prompt_progress_callback: Optional[Callable[int, int]] = None,
     temp: Optional[float] = None,
     repetition_penalty: Optional[float] = None,
     repetition_context_size: Optional[int] = None,
@@ -204,21 +206,25 @@ def generate_step(
     Args:
         prompt (mx.array): The input prompt.
         model (nn.Module): The model to use for generation.
-        prefill_step_size (int): Step size for processing the prompt.
-        max_kv_size (int, optional): Maximum size of the key-value cache. Old
-            entries (except the first 4 tokens) will be overwritten.
-        prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
-            provided, the cache will be updated in place.
+        max_tokens (int): The maximum number of tokens. Use ``-1`` for an infinite
+            generator. Default: ``256``.
         sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a
             token from a vector of log probabilities. Default: ``None``.
         logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
             A list of functions that take tokens and logits and return the processed
             logits. Default: ``None``.
+        max_kv_size (int, optional): Maximum size of the key-value cache. Old
+            entries (except the first 4 tokens) will be overwritten.
+        prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
+            provided, the cache will be updated in place.
+        prefill_step_size (int): Step size for processing the prompt.
         kv_bits (int, optional): Number of bits to use for KV cache quantization.
             None implies no cache quantization. Default: ``None``.
         kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
         quantized_kv_start (int): Step to begin using a quantized KV cache.
            when ``kv_bits`` is non-None. Default: ``0``.
+        prompt_progress_callback (Callable[int, int]): A callback which takes the
+            prompt tokens processed so far and the total number of prompt tokens.
 
     Yields:
         Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
@@ -253,6 +259,7 @@ def generate_step(
     logits_processors = logits_processors or make_logits_processors(
         None, repetition_penalty, repetition_context_size or 20
     )
+    prompt_progress_callback = prompt_progress_callback or (lambda *_: None)
 
     def _step(y):
         with mx.stream(generation_stream):
@@ -275,9 +282,13 @@ def generate_step(
             return y, logprobs.squeeze(0)
 
     with mx.stream(generation_stream):
+        total_prompt_tokens = y.size
+        prompt_processed_tokens = 0
         while y.size > prefill_step_size:
             model(y[:prefill_step_size][None], cache=prompt_cache)
             mx.eval([c.state for c in prompt_cache])
+            prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
+            prompt_processed_tokens += prefill_step_size
             y = y[prefill_step_size:]
             mx.metal.clear_cache()
 
@@ -286,20 +297,25 @@ def generate_step(
         mx.async_eval(y, logprobs)
         n = 0
         while True:
-            next_y, next_logprobs = _step(y)
-            mx.async_eval(next_y, next_logprobs)
+            if n != max_tokens:
+                next_y, next_logprobs = _step(y)
+                mx.async_eval(next_y, next_logprobs)
+            if n == 0:
+                mx.eval(y)
+                prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
+            if n == max_tokens:
+                break
             yield y.item(), logprobs
             if n % 256 == 0:
                 mx.metal.clear_cache()
-            n += 1
             y, logprobs = next_y, next_logprobs
+            n += 1
 
 
 def stream_generate(
     model: nn.Module,
     tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
     prompt: Union[str, mx.array, List[int]],
-    max_tokens: int = 100,
     **kwargs,
 ) -> Generator[GenerationResponse, None, None]:
     """
@@ -309,7 +325,6 @@ def stream_generate(
         model (nn.Module): The model to use for generation.
         tokenizer (PreTrainedTokenizer): The tokenizer.
         prompt (Union[str, mx.array, List[int]]): The input prompt string or integer tokens.
-        max_tokens (int): The maximum number of tokens. Default: ``100``.
         kwargs: The remaining options get passed to :func:`generate_step`.
             See :func:`generate_step` for more details.
 
@@ -330,10 +345,7 @@ def stream_generate(
     with wired_limit(model, [generation_stream]):
         detokenizer.reset()
         tic = time.perf_counter()
-        for n, (token, logprobs) in zip(
-            range(max_tokens),
-            generate_step(prompt, model, **kwargs),
-        ):
+        for n, (token, logprobs) in enumerate(generate_step(prompt, model, **kwargs)):
             if n == 0:
                 prompt_time = time.perf_counter() - tic
                 prompt_tps = prompt.size / prompt_time
@@ -343,9 +355,6 @@ def stream_generate(
 
             detokenizer.add_token(token)
 
-            if n == (max_tokens - 1):
-                break
-
             yield GenerationResponse(
                 text=detokenizer.last_segment,
                 token=token,
@@ -385,7 +394,6 @@ def generate(
         model (nn.Module): The language model.
         tokenizer (PreTrainedTokenizer): The tokenizer.
         prompt (str): The string prompt.
-        max_tokens (int): The maximum number of tokens. Default: ``100``.
         verbose (bool): If ``True``, print tokens and timing information.
             Default: ``False``.
         kwargs: The remaining options get passed to :func:`stream_generate`.
diff --git a/llms/tests/test_prompt_cache.py b/llms/tests/test_prompt_cache.py
index 0867ab56..de5694d5 100644
--- a/llms/tests/test_prompt_cache.py
+++ b/llms/tests/test_prompt_cache.py
@@ -121,21 +121,20 @@ class TestPromptCache(unittest.TestCase):
     def test_cache_with_generate(self):
         model, tokenizer = load(HF_MODEL_PATH)
         prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]
-        results = zip(range(4), generate_step(prompt, model))
-        toks, all_logits = zip(*(r[1] for r in results))
+        results = list(generate_step(prompt, model, max_tokens=4))
+        toks, all_logits = zip(*results)
 
         prompt_cache = make_prompt_cache(model)
         i = 0
-        for _, (tok, logits) in zip(
-            range(2), generate_step(prompt, model, prompt_cache=prompt_cache)
+        for tok, logits in generate_step(
+            prompt, model, prompt_cache=prompt_cache, max_tokens=2
         ):
             self.assertEqual(tok, toks[i])
             self.assertTrue(mx.allclose(logits, all_logits[i]))
             i += 1
 
-        for _, (tok, logits) in zip(
-            range(1),
-            generate_step(mx.array([toks[i]]), model, prompt_cache=prompt_cache),
+        for tok, logits in generate_step(
+            mx.array([toks[i]]), model, prompt_cache=prompt_cache, max_tokens=1
         ):
             i += 1
             self.assertEqual(tok, toks[i])
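
With this change the generation length and the prefill progress hook both live on generate_step
(and flow through stream_generate/generate via kwargs). A minimal usage sketch under those
assumptions; the model repo id, prompt text, and the on_progress helper below are placeholders
for illustration, not names defined by mlx-lm:

    import mlx.core as mx

    from mlx_lm.utils import generate_step, load

    # Placeholder repo id; any mlx-lm compatible model works here.
    model, tokenizer = load("mlx-community/Model-4bit")

    prompt = mx.array(tokenizer.encode("Write a haiku about the ocean."))

    def on_progress(processed, total):
        # Invoked during prefill with (prompt tokens processed so far, total prompt tokens).
        print(f"\rPrefill: {processed}/{total}", end="", flush=True)

    tokens = []
    for token, _logprobs in generate_step(
        prompt,
        model,
        max_tokens=64,  # generation length is now a generate_step argument
        prompt_progress_callback=on_progress,
    ):
        tokens.append(token)

    print()
    print(tokenizer.decode(tokens))

Passing max_tokens=0, as cache_prompt.py now does, runs only the prefill phase, which is how the
prompt cache gets filled without sampling any new tokens.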