diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py
index 26481d6b..0d286c75 100644
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -131,6 +131,18 @@ def setup_arg_parser():
         type=int,
         default=DEFAULT_QUANTIZED_KV_START,
     )
+    parser.add_argument(
+        "--draft-model",
+        type=str,
+        help="A model to be used for speculative decoding.",
+        default=None,
+    )
+    parser.add_argument(
+        "--num-draft-tokens",
+        type=int,
+        help="Number of tokens to draft when using speculative decoding.",
+        default=2,
+    )
     return parser
 
 
@@ -211,11 +223,16 @@ def main():
                 add_generation_prompt=True,
             )
             prompt = prompt[test_prompt.index("<query>") :]
-        prompt = tokenizer.encode(prompt, add_special_tokens=False)
     else:
         prompt = tokenizer.encode(prompt)
 
+    if args.draft_model is not None:
+        draft_model, draft_tokenizer = load(args.draft_model)
+        if draft_tokenizer.vocab_size != tokenizer.vocab_size:
+            raise ValueError("Draft model tokenizer does not match model tokenizer.")
+    else:
+        draft_model = None
     sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
     response = generate(
         model,
@@ -229,6 +246,8 @@ def main():
         kv_bits=args.kv_bits,
         kv_group_size=args.kv_group_size,
         quantized_kv_start=args.quantized_kv_start,
+        draft_model=draft_model,
+        num_draft_tokens=args.num_draft_tokens,
     )
     if not args.verbose:
         print(response)
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 0e06b5a0..2fc0446b 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -2,6 +2,7 @@
 
 import contextlib
 import copy
+import functools
 import glob
 import importlib
 import json
@@ -207,12 +208,6 @@ def generate_step(
     kv_group_size: int = 64,
     quantized_kv_start: int = 0,
     prompt_progress_callback: Optional[Callable[int, int]] = None,
-    temp: Optional[float] = None,
-    repetition_penalty: Optional[float] = None,
-    repetition_context_size: Optional[int] = None,
-    top_p: Optional[float] = None,
-    min_p: Optional[float] = None,
-    min_tokens_to_keep: Optional[int] = None,
 ) -> Generator[Tuple[mx.array, mx.array], None, None]:
     """
     A generator producing token ids based on the given prompt from the model.
@@ -256,25 +251,17 @@ def generate_step(
     elif len(prompt_cache) != len(model.layers):
         raise ValueError("Wrong number of layers in the prompt cache.")
 
-    if temp is not None or top_p is not None or min_tokens_to_keep is not None:
-        print(
-            "[Warning] Specifying sampling arguments to ``generate_step`` is "
-            "deprecated. Pass in a ``sampler`` instead."
-        )
-    if repetition_penalty is not None:
-        print(
-            "[Warning] Specifying ``repetition_penalty`` is deprecated. "
-            "Pass in ``logits_processors`` instead."
-        )
-
-    sampler = sampler or make_sampler(
-        temp or 0.0, top_p or 0.0, min_p or 0.0, min_tokens_to_keep or 1
-    )
-    logits_processors = logits_processors or make_logits_processors(
-        None, repetition_penalty, repetition_context_size or 20
-    )
     prompt_progress_callback = prompt_progress_callback or (lambda *_: None)
 
+    quantize_cache_fn = functools.partial(
+        maybe_quantize_kv_cache,
+        quantized_kv_start=quantized_kv_start,
+        kv_group_size=kv_group_size,
+        kv_bits=kv_bits,
+    )
+
+    sampler = sampler or (lambda x: mx.argmax(x, axis=-1))
+
     def _step(y):
         with mx.stream(generation_stream):
             logits = model(y[None], cache=prompt_cache)
@@ -287,9 +274,7 @@ def generate_step(
                 for processor in logits_processors:
                     logits = processor(tokens, logits)
 
-            maybe_quantize_kv_cache(
-                prompt_cache, quantized_kv_start, kv_group_size, kv_bits
-            )
+            quantize_cache_fn(prompt_cache)
 
             logprobs = logits - mx.logsumexp(logits, keepdims=True)
             y = sampler(logprobs)
@@ -300,9 +285,7 @@ def generate_step(
         prompt_processed_tokens = 0
         while y.size > prefill_step_size:
             model(y[:prefill_step_size][None], cache=prompt_cache)
-            maybe_quantize_kv_cache(
-                prompt_cache, quantized_kv_start, kv_group_size, kv_bits
-            )
+            quantize_cache_fn(prompt_cache)
             mx.eval([c.state for c in prompt_cache])
             prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
             prompt_processed_tokens += prefill_step_size
@@ -329,10 +312,162 @@ def generate_step(
             n += 1
 
 
+def speculative_generate_step(
+    prompt: mx.array,
+    model: nn.Module,
+    draft_model: nn.Module,
+    *,
+    num_draft_tokens=2,
+    max_tokens: int = 256,
+    sampler: Optional[Callable[mx.array, mx.array]] = None,
+    logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
+    prompt_cache: Optional[Any] = None,
+    prefill_step_size: int = 512,
+    kv_bits: Optional[int] = None,
+    kv_group_size: int = 64,
+    quantized_kv_start: int = 0,
+) -> Generator[Tuple[mx.array, mx.array], None, None]:
+    """
+    A generator producing token ids based on the given prompt from the model.
+
+    Args:
+        prompt (mx.array): The input prompt.
+        model (nn.Module): The model to use for generation.
+        draft_model (nn.Module): The draft model for speculative decoding.
+        num_draft_tokens (int, optional): The number of draft tokens for
+          speculative decoding. Default: ``2``.
+        max_tokens (int): The maximum number of tokens. Use ``-1`` for an infinite
+          generator. Default: ``256``.
+        sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a
+          token from a vector of log probabilities. Default: ``None``.
+        logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
+          A list of functions that take tokens and logits and return the processed
+          logits. Default: ``None``.
+        prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
+          provided, the cache will be updated in place. The cache must be trimmable.
+        prefill_step_size (int): Step size for processing the prompt.
+        kv_bits (int, optional): Number of bits to use for KV cache quantization.
+          None implies no cache quantization. Default: ``None``.
+        kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
+        quantized_kv_start (int): Step to begin using a quantized KV cache
+          when ``kv_bits`` is non-None. Default: ``0``.
+
+    Yields:
+        Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
+    """
+
+    y = prompt
+    tokens = None
+
+    # Create the KV cache for generation
+    if prompt_cache is None:
+        model_cache = cache.make_prompt_cache(model)
+        draft_cache = cache.make_prompt_cache(draft_model)
+    elif len(prompt_cache) != (len(model.layers) + len(draft_model.layers)):
+        raise ValueError("Wrong number of layers in the prompt cache.")
+    else:
+        model_cache = prompt_cache[: len(model.layers)]
+        draft_cache = prompt_cache[len(model.layers) :]
+
+    sampler = sampler or (lambda x: mx.argmax(x, axis=-1))
+
+    quantize_cache_fn = functools.partial(
+        maybe_quantize_kv_cache,
+        quantized_kv_start=quantized_kv_start,
+        kv_group_size=kv_group_size,
+        kv_bits=kv_bits,
+    )
+
+    def _step(model, cache, y, n_predict=1):
+        with mx.stream(generation_stream):
+            logits = model(y[None], cache=cache)
+            logits = logits[:, -n_predict:, :]
+
+            quantize_cache_fn(cache)
+
+            logprobs = logits - mx.logsumexp(logits, keepdims=True)
+            y = sampler(logprobs).squeeze(0)
+            return y, logprobs.squeeze(0)
+
+    def _prefill(model, cache, y):
+        while y.size > prefill_step_size:
+            model(y[:prefill_step_size][None], cache=cache)
+            quantize_cache_fn(cache)
+            mx.eval([c.state for c in cache])
+            y = y[prefill_step_size:]
+            mx.metal.clear_cache()
+        return y
+
+    def _rewind_cache(num_draft, num_accept):
+        cache.trim_prompt_cache(model_cache, num_draft - num_accept)
+        cache.trim_prompt_cache(draft_cache, max(num_draft - num_accept - 1, 0))
+
+    def _draft_generate(y, num_draft):
+        if num_draft == 0:
+            return mx.array([], mx.uint32)
+        ys = []
+        for _ in range(num_draft):
+            y, _ = _step(draft_model, draft_cache, y)
+            mx.async_eval(y)
+            ys.append(y)
+        return mx.concatenate(ys)
+
+    with mx.stream(generation_stream):
+        draft_y = _prefill(draft_model, draft_cache, y)
+        y = _prefill(model, model_cache, y)
+
+    ntoks = 0
+    # Set these so the finally block doesn't raise
+    num_draft = 0
+    n = 0
+    try:
+        while True:
+            num_draft = min(max_tokens - ntoks, num_draft_tokens)
+            draft_tokens = _draft_generate(draft_y, num_draft)
+            y = mx.concatenate([y, draft_tokens])
+
+            tokens, logprobs = _step(model, model_cache, y, num_draft + 1)
+            mx.eval(tokens, draft_tokens)
+            draft_tokens = draft_tokens.tolist()
+            tokens = tokens.tolist()
+            n = 0
+            while n < num_draft:
+                tn, dtn, lpn = tokens[n], draft_tokens[n], logprobs[n]
+                if tn != dtn:
+                    break
+                n += 1
+                ntoks += 1
+                yield tn, lpn
+                if ntoks == max_tokens:
+                    break
+            if ntoks < max_tokens:
+                ntoks += 1
+                yield tokens[n], logprobs[n]
+
+            if ntoks == max_tokens:
+                break
+
+            y = mx.array([tokens[n]], mx.uint32)
+            draft_y = y
+
+            # If we accepted all the draft tokens, include the last
+            # draft token in the next draft step since it hasn't been
+            # processed yet by the draft model
+            if n == num_draft:
+                draft_y = mx.concatenate(
+                    [mx.array(draft_tokens[-1:], mx.uint32), draft_y]
+                )
+
+            _rewind_cache(num_draft, n)
+    finally:
+        _rewind_cache(num_draft, n)
+
+
 def stream_generate(
     model: nn.Module,
     tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
     prompt: Union[str, mx.array, List[int]],
+    draft_model: Optional[nn.Module] = None,
     **kwargs,
 ) -> Generator[GenerationResponse, None, None]:
     """
@@ -341,7 +476,11 @@ def stream_generate(
     Args:
         model (nn.Module): The model to use for generation.
         tokenizer (PreTrainedTokenizer): The tokenizer.
-        prompt (Union[str, mx.array, List[int]]): The input prompt string or integer tokens.
+        prompt (Union[str, mx.array, List[int]]): The input prompt string or
+          integer tokens.
+        draft_model (Optional[nn.Module]): An optional draft model. If provided
+          then speculative decoding is used. The draft model must use the same
+          tokenizer as the main model. Default: ``None``.
         kwargs: The remaining options get passed to :func:`generate_step`.
           See :func:`generate_step` for more details.
 
@@ -363,10 +502,18 @@ def stream_generate(
 
     detokenizer = tokenizer.detokenizer
 
+    if draft_model is None:
+        kwargs.pop("num_draft_tokens", None)
+        token_generator = generate_step(prompt, model, **kwargs)
+    else:
+        kwargs.pop("max_kv_size", None)
+        token_generator = speculative_generate_step(
+            prompt, model, draft_model, **kwargs
+        )
     with wired_limit(model, [generation_stream]):
         detokenizer.reset()
        tic = time.perf_counter()
-        for n, (token, logprobs) in enumerate(generate_step(prompt, model, **kwargs)):
+        for n, (token, logprobs) in enumerate(token_generator):
             if n == 0:
                 prompt_time = time.perf_counter() - tic
                 prompt_tps = prompt.size / prompt_time
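
Usage sketch (not a definitive recipe): the snippet below assumes the `load` and `stream_generate` helpers defined in `mlx_lm.utils`, and the model paths are placeholders. Passing a `draft_model` routes `stream_generate` through the new `speculative_generate_step`; `num_draft_tokens` is only consumed on that path (it is popped otherwise). The CLI exercises the same path via the new `--draft-model` and `--num-draft-tokens` flags.

# A minimal sketch of the new speculative decoding path added in this patch.
# The repo/path strings below are placeholders, not real model names.
from mlx_lm.utils import load, stream_generate

model, tokenizer = load("<path-or-repo-of-main-model>")
draft_model, draft_tokenizer = load("<path-or-repo-of-draft-model>")

# The draft model must share the main model's tokenizer; generate.py above
# enforces this with a vocab_size comparison.
assert draft_tokenizer.vocab_size == tokenizer.vocab_size

prompt = tokenizer.encode("Write a haiku about speculative decoding.")

for response in stream_generate(
    model,
    tokenizer,
    prompt,
    draft_model=draft_model,   # enables speculative_generate_step
    num_draft_tokens=2,        # tokens drafted per verification step
    max_tokens=128,
):
    print(response.text, end="", flush=True)
print()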