diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 0e06b5a0..102512c1 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -2,6 +2,7 @@
 
 import contextlib
 import copy
+import functools
 import glob
 import importlib
 import json
@@ -207,12 +208,6 @@ def generate_step(
     kv_group_size: int = 64,
     quantized_kv_start: int = 0,
     prompt_progress_callback: Optional[Callable[int, int]] = None,
-    temp: Optional[float] = None,
-    repetition_penalty: Optional[float] = None,
-    repetition_context_size: Optional[int] = None,
-    top_p: Optional[float] = None,
-    min_p: Optional[float] = None,
-    min_tokens_to_keep: Optional[int] = None,
 ) -> Generator[Tuple[mx.array, mx.array], None, None]:
     """
     A generator producing token ids based on the given prompt from the model.
@@ -256,25 +251,15 @@ def generate_step(
     elif len(prompt_cache) != len(model.layers):
         raise ValueError("Wrong number of layers in the prompt cache.")
 
-    if temp is not None or top_p is not None or min_tokens_to_keep is not None:
-        print(
-            "[Warning] Specifying sampling arguments to ``generate_step`` is "
-            "deprecated. Pass in a ``sampler`` instead."
-        )
-    if repetition_penalty is not None:
-        print(
-            "[Warning] Specifying ``repetition_penalty`` is deprecated. "
-            "Pass in ``logits_processors`` instead."
-        )
-
-    sampler = sampler or make_sampler(
-        temp or 0.0, top_p or 0.0, min_p or 0.0, min_tokens_to_keep or 1
-    )
-    logits_processors = logits_processors or make_logits_processors(
-        None, repetition_penalty, repetition_context_size or 20
-    )
     prompt_progress_callback = prompt_progress_callback or (lambda *_: None)
 
+    quantize_cache_fn = functools.partial(
+        maybe_quantize_kv_cache,
+        quantized_kv_start=quantized_kv_start,
+        kv_group_size=kv_group_size,
+        kv_bits=kv_bits,
+    )
+
     def _step(y):
         with mx.stream(generation_stream):
             logits = model(y[None], cache=prompt_cache)
@@ -287,9 +272,7 @@ def generate_step(
             for processor in logits_processors:
                 logits = processor(tokens, logits)
 
-            maybe_quantize_kv_cache(
-                prompt_cache, quantized_kv_start, kv_group_size, kv_bits
-            )
+            quantize_cache_fn(prompt_cache)
 
             logprobs = logits - mx.logsumexp(logits, keepdims=True)
             y = sampler(logprobs)
@@ -300,9 +283,7 @@ def generate_step(
     prompt_processed_tokens = 0
     while y.size > prefill_step_size:
         model(y[:prefill_step_size][None], cache=prompt_cache)
-        maybe_quantize_kv_cache(
-            prompt_cache, quantized_kv_start, kv_group_size, kv_bits
-        )
+        quantize_cache_fn(prompt_cache)
         mx.eval([c.state for c in prompt_cache])
         prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
         prompt_processed_tokens += prefill_step_size
@@ -329,10 +310,150 @@ def generate_step(
         n += 1
 
 
+def speculative_generate_step(
+    prompt: mx.array,
+    model: nn.Module,
+    draft_model: nn.Module,
+    *,
+    num_draft_tokens=2,
+    max_tokens: int = 256,
+    prompt_cache: Optional[Any] = None,
+    prefill_step_size: int = 512,
+    kv_bits: Optional[int] = None,
+    kv_group_size: int = 64,
+    quantized_kv_start: int = 0,
+) -> Generator[Tuple[mx.array, mx.array], None, None]:
+    """
+    A generator producing token ids based on the given prompt from the model.
+
+    Args:
+        prompt (mx.array): The input prompt.
+        model (nn.Module): The model to use for generation.
+        draft_model (nn.Module): The draft model for speculative decoding.
+        num_draft_tokens (int, optional): The number of draft tokens for
+          speculative decoding. Default: ``2``.
+        max_tokens (int): The maximum number of tokens. Use ``-1`` for an infinite
+          generator. Default: ``256``.
+        prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
+          provided, the cache will be updated in place. The cache must be trimmable.
+        prefill_step_size (int): Step size for processing the prompt.
+        kv_bits (int, optional): Number of bits to use for KV cache quantization.
+          None implies no cache quantization. Default: ``None``.
+        kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
+        quantized_kv_start (int): Step to begin using a quantized KV cache
+          when ``kv_bits`` is non-None. Default: ``0``.
+
+    Yields:
+        Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
+    """
+
+    y = prompt
+    tokens = None
+
+    # Create the KV cache for generation
+    if prompt_cache is None:
+        model_cache = cache.make_prompt_cache(model)
+        draft_cache = cache.make_prompt_cache(draft_model)
+    elif len(prompt_cache) != (len(model.layers) + len(draft_model.layers)):
+        raise ValueError("Wrong number of layers in the prompt cache.")
+    else:
+        model_cache = prompt_cache[: len(model.layers)]
+        draft_cache = prompt_cache[len(model.layers) :]
+
+    quantize_cache_fn = functools.partial(
+        maybe_quantize_kv_cache,
+        quantized_kv_start=quantized_kv_start,
+        kv_group_size=kv_group_size,
+        kv_bits=kv_bits,
+    )
+
+    def _step(model, cache, y, n_predict=1):
+        with mx.stream(generation_stream):
+            logits = model(y[None], cache=cache)
+            logits = logits[:, -n_predict:, :]
+
+            quantize_cache_fn(cache)
+
+            logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
+            y = mx.argmax(logprobs, axis=-1).squeeze(0)
+            return y, logprobs.squeeze(0)
+
+    def _prefill(model, cache, y):
+        while y.size > prefill_step_size:
+            model(y[:prefill_step_size][None], cache=cache)
+            quantize_cache_fn(cache)
+            mx.eval([c.state for c in cache])
+            y = y[prefill_step_size:]
+            mx.metal.clear_cache()
+        return y
+
+    def _rewind_cache(num_draft, num_accept):
+        cache.trim_prompt_cache(model_cache, num_draft - num_accept)
+        cache.trim_prompt_cache(draft_cache, max(num_draft - num_accept - 1, 0))
+
+    def _draft_generate(y, num_draft):
+        if num_draft == 0:
+            return mx.array([], mx.uint32)
+        ys = []
+        for _ in range(num_draft):
+            y, _ = _step(draft_model, draft_cache, y)
+            mx.async_eval(y)
+            ys.append(y)
+        return mx.concatenate(ys)
+
+    with mx.stream(generation_stream):
+        draft_y = _prefill(draft_model, draft_cache, y)
+        y = _prefill(model, model_cache, y)
+
+    ntoks = 0
+    try:
+        while True:
+            num_draft = min(max_tokens - ntoks, num_draft_tokens)
+            draft_tokens = _draft_generate(draft_y, num_draft)
+            y = mx.concatenate([y, draft_tokens])
+
+            tokens, logprobs = _step(model, model_cache, y, num_draft + 1)
+            mx.eval(tokens, draft_tokens)
+            draft_tokens = draft_tokens.tolist()
+            tokens = tokens.tolist()
+            n = 0
+            while n < num_draft:
+                tn, dtn, lpn = tokens[n], draft_tokens[n], logprobs[n]
+                if tn != dtn:
+                    break
+                n += 1
+                ntoks += 1
+                yield tn, lpn
+                if ntoks == max_tokens:
+                    break
+            if ntoks < max_tokens:
+                ntoks += 1
+                yield tokens[n], logprobs[n]
+
+            if ntoks == max_tokens:
+                break
+
+            y = mx.array([tokens[n]], mx.uint32)
+            draft_y = y
+
+            # If we accepted all the draft tokens, include the last
+            # draft token in the next draft step since it hasn't been
+            # processed yet by the draft model
+            if n == num_draft:
+                draft_y = mx.concatenate(
+                    [mx.array(draft_tokens[-1:], mx.uint32), draft_y]
+                )
+
+            _rewind_cache(num_draft, n)
+    finally:
+        _rewind_cache(num_draft, n)
+
+
 def stream_generate(
     model: nn.Module,
     tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
     prompt: Union[str, mx.array, List[int]],
+    draft_model: Optional[nn.Module] = None,
     **kwargs,
 ) -> Generator[GenerationResponse, None, None]:
     """
@@ -341,7 +462,11 @@ def stream_generate(
     Args:
         model (nn.Module): The model to use for generation.
         tokenizer (PreTrainedTokenizer): The tokenizer.
-        prompt (Union[str, mx.array, List[int]]): The input prompt string or integer tokens.
+        prompt (Union[str, mx.array, List[int]]): The input prompt string or
+          integer tokens.
+        draft_model (Optional[nn.Module]): An optional draft model. If provided
+          then speculative decoding is used. The draft model must use the same
+          tokenizer as the main model. Default: ``None``.
         kwargs: The remaining options get passed to :func:`generate_step`.
           See :func:`generate_step` for more details.
 
@@ -363,10 +488,16 @@ def stream_generate(
 
     detokenizer = tokenizer.detokenizer
 
+    if draft_model is None:
+        token_generator = generate_step(prompt, model, **kwargs)
+    else:
+        token_generator = speculative_generate_step(
+            prompt, model, draft_model, **kwargs
+        )
+
     with wired_limit(model, [generation_stream]):
         detokenizer.reset()
         tic = time.perf_counter()
-        for n, (token, logprobs) in enumerate(generate_step(prompt, model, **kwargs)):
+        for n, (token, logprobs) in enumerate(token_generator):
             if n == 0:
                 prompt_time = time.perf_counter() - tic
                 prompt_tps = prompt.size / prompt_time
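For reference, a minimal sketch of driving the new speculative_generate_step generator directly. It assumes two already-loaded checkpoints that share a tokenizer; the model names are placeholders, not part of this change, and the eos check relies on the tokenizer's standard eos_token_id attribute.

# Illustrative only: use speculative_generate_step directly.
import mlx.core as mx
from mlx_lm import load
from mlx_lm.utils import speculative_generate_step

model, tokenizer = load("<main-model>")
draft_model, _ = load("<draft-model>")  # must share the main model's tokenizer

prompt = mx.array(tokenizer.encode("Explain speculative decoding in one sentence."))

generated = []
for token, logprobs in speculative_generate_step(
    prompt,
    model,
    draft_model,
    num_draft_tokens=3,
    max_tokens=128,
):
    # Each step yields one accepted token id and its log-probability vector.
    if token == tokenizer.eos_token_id:
        break
    generated.append(token)

print(tokenizer.decode(generated))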
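And a sketch of the user-facing path through stream_generate, where passing draft_model switches generation to speculative_generate_step and keyword arguments such as num_draft_tokens are forwarded via **kwargs. Again the model names are placeholders.

# Illustrative only: speculative decoding through stream_generate.
from mlx_lm import load, stream_generate

model, tokenizer = load("<main-model>")
draft_model, _ = load("<draft-model>")  # must share the main model's tokenizer

for response in stream_generate(
    model,
    tokenizer,
    "Write a haiku about the ocean.",
    draft_model=draft_model,   # omit (None) to fall back to generate_step
    num_draft_tokens=2,        # forwarded to speculative_generate_step
    max_tokens=256,
):
    print(response.text, end="", flush=True)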