Mirror of https://github.com/ml-explore/mlx-examples.git
	Add a speculative decoding generator (#1155)
* add a speculative decoding generator
* fix
* fixes
* optional kwarg pop
@@ -131,6 +131,18 @@ def setup_arg_parser():
        type=int,
        default=DEFAULT_QUANTIZED_KV_START,
    )
    parser.add_argument(
        "--draft-model",
        type=str,
        help="A model to be used for speculative decoding.",
        default=None,
    )
    parser.add_argument(
        "--num-draft-tokens",
        type=int,
        help="Number of tokens to draft when using speculative decoding.",
        default=2,
    )
    return parser


@@ -211,11 +223,16 @@ def main():
                add_generation_prompt=True,
            )
            prompt = prompt[test_prompt.index("<query>") :]

        prompt = tokenizer.encode(prompt, add_special_tokens=False)
    else:
        prompt = tokenizer.encode(prompt)

    if args.draft_model is not None:
        draft_model, draft_tokenizer = load(args.draft_model)
        if draft_tokenizer.vocab_size != tokenizer.vocab_size:
            raise ValueError("Draft model tokenizer does not match model tokenizer.")
    else:
        draft_model = None
    sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
    response = generate(
        model,
@@ -229,6 +246,8 @@ def main():
        kv_bits=args.kv_bits,
        kv_group_size=args.kv_group_size,
        quantized_kv_start=args.quantized_kv_start,
        draft_model=draft_model,
        num_draft_tokens=args.num_draft_tokens,
    )
    if not args.verbose:
        print(response)

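For context, the same flow can be driven from Python rather than the CLI. The sketch below is illustrative only: it mirrors what `main()` now does when `--draft-model` is given, uses placeholder model paths, and assumes the usual `mlx_lm` API (`load` returning a `(model, tokenizer)` pair, as in the hunk above, and a `generate(model, tokenizer, prompt, ...)` signature).

```python
# Illustrative sketch, not part of the diff. Model paths are placeholders.
from mlx_lm import load, generate

model, tokenizer = load("<main-model>")
draft_model, draft_tokenizer = load("<draft-model>")

# Same sanity check as main(): the two tokenizers must agree on vocabulary size.
if draft_tokenizer.vocab_size != tokenizer.vocab_size:
    raise ValueError("Draft model tokenizer does not match model tokenizer.")

response = generate(
    model,
    tokenizer,
    "Write a haiku about autumn.",
    draft_model=draft_model,  # enables speculative decoding
    num_draft_tokens=2,       # matches the new CLI default
)
print(response)
```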
@@ -2,6 +2,7 @@

import contextlib
import copy
import functools
import glob
import importlib
import json
@@ -207,12 +208,6 @@ def generate_step(
    kv_group_size: int = 64,
    quantized_kv_start: int = 0,
    prompt_progress_callback: Optional[Callable[int, int]] = None,
    temp: Optional[float] = None,
    repetition_penalty: Optional[float] = None,
    repetition_context_size: Optional[int] = None,
    top_p: Optional[float] = None,
    min_p: Optional[float] = None,
    min_tokens_to_keep: Optional[int] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
    """
    A generator producing token ids based on the given prompt from the model.
@@ -256,25 +251,17 @@ def generate_step(
    elif len(prompt_cache) != len(model.layers):
        raise ValueError("Wrong number of layers in the prompt cache.")

    if temp is not None or top_p is not None or min_tokens_to_keep is not None:
        print(
            "[Warning] Specifying sampling arguments to ``generate_step`` is "
            "deprecated. Pass in a ``sampler`` instead."
        )
    if repetition_penalty is not None:
        print(
            "[Warning] Specifying ``repetition_penalty`` is deprecated. "
            "Pass in ``logits_processors`` instead."
        )

    sampler = sampler or make_sampler(
        temp or 0.0, top_p or 0.0, min_p or 0.0, min_tokens_to_keep or 1
    )
    logits_processors = logits_processors or make_logits_processors(
        None, repetition_penalty, repetition_context_size or 20
    )
    prompt_progress_callback = prompt_progress_callback or (lambda *_: None)

    quantize_cache_fn = functools.partial(
        maybe_quantize_kv_cache,
        quantized_kv_start=quantized_kv_start,
        kv_group_size=kv_group_size,
        kv_bits=kv_bits,
    )

    sampler = sampler or (lambda x: mx.argmax(x, axis=-1))

    def _step(y):
        with mx.stream(generation_stream):
            logits = model(y[None], cache=prompt_cache)
@@ -287,9 +274,7 @@ def generate_step(
                for processor in logits_processors:
                    logits = processor(tokens, logits)

            maybe_quantize_kv_cache(
                prompt_cache, quantized_kv_start, kv_group_size, kv_bits
            )
            quantize_cache_fn(prompt_cache)

            logprobs = logits - mx.logsumexp(logits, keepdims=True)
            y = sampler(logprobs)
@@ -300,9 +285,7 @@ def generate_step(
        prompt_processed_tokens = 0
        while y.size > prefill_step_size:
            model(y[:prefill_step_size][None], cache=prompt_cache)
            maybe_quantize_kv_cache(
                prompt_cache, quantized_kv_start, kv_group_size, kv_bits
            )
            quantize_cache_fn(prompt_cache)
            mx.eval([c.state for c in prompt_cache])
            prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
            prompt_processed_tokens += prefill_step_size
@@ -329,10 +312,162 @@ def generate_step(
        n += 1


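Note that the deprecated sampling keyword arguments (`temp`, `top_p`, `min_p`, `min_tokens_to_keep`, `repetition_penalty`, `repetition_context_size`) are now removed from `generate_step` rather than merely warned about. Callers that relied on them must build the equivalents explicitly. A hedged sketch of the replacement, assuming `make_sampler` and `make_logits_processors` live in `mlx_lm.sample_utils` and keep the argument order used by the removed fallback code:

```python
# Illustrative sketch, not part of the diff. The model path is a placeholder.
import mlx.core as mx
from mlx_lm import load
from mlx_lm.sample_utils import make_logits_processors, make_sampler
from mlx_lm.utils import generate_step

model, tokenizer = load("<model>")
prompt = mx.array(tokenizer.encode("Hello"))

# Previously: generate_step(..., temp=0.7, top_p=0.9, repetition_penalty=1.1)
sampler = make_sampler(0.7, 0.9, 0.0, 1)  # temp, top_p, min_p, min_tokens_to_keep
logits_processors = make_logits_processors(None, 1.1, 20)  # bias, penalty, context size

tokens = []
for token, _logprobs in generate_step(
    prompt,
    model,
    sampler=sampler,
    logits_processors=logits_processors,
    max_tokens=32,
):
    tokens.append(token)
print(tokenizer.decode(tokens))
```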
def speculative_generate_step(
    prompt: mx.array,
    model: nn.Module,
    draft_model: nn.Module,
    *,
    num_draft_tokens=2,
    max_tokens: int = 256,
    sampler: Optional[Callable[mx.array, mx.array]] = None,
    logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
    prompt_cache: Optional[Any] = None,
    prefill_step_size: int = 512,
    kv_bits: Optional[int] = None,
    kv_group_size: int = 64,
    quantized_kv_start: int = 0,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
    """
    A generator producing token ids based on the given prompt from the model.

    Args:
        prompt (mx.array): The input prompt.
        model (nn.Module): The model to use for generation.
        draft_model (nn.Module): The draft model for speculative decoding.
        num_draft_tokens (int, optional): The number of draft tokens for
          speculative decoding. Default: ``2``.
        max_tokens (int): The maximum number of tokens. Use ``-1`` for an infinite
          generator. Default: ``256``.
        sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a
          token from a vector of log probabilities. Default: ``None``.
        logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
          A list of functions that take tokens and logits and return the processed
          logits. Default: ``None``.
        prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
          provided, the cache will be updated in place. The cache must be trimmable.
        prefill_step_size (int): Step size for processing the prompt.
        kv_bits (int, optional): Number of bits to use for KV cache quantization.
          None implies no cache quantization. Default: ``None``.
        kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
        quantized_kv_start (int): Step to begin using a quantized KV cache
          when ``kv_bits`` is non-None. Default: ``0``.

    Yields:
        Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
    """

    y = prompt
    tokens = None

    # Create the KV cache for generation
    if prompt_cache is None:
        model_cache = cache.make_prompt_cache(model)
        draft_cache = cache.make_prompt_cache(draft_model)
    elif len(prompt_cache) != (len(model.layers) + len(draft_model.layers)):
        raise ValueError("Wrong number of layers in the prompt cache.")
    else:
        model_cache = prompt_cache[: len(model.layers)]
        draft_cache = prompt_cache[len(model.layers) :]

    sampler = sampler or (lambda x: mx.argmax(x, axis=-1))

    quantize_cache_fn = functools.partial(
        maybe_quantize_kv_cache,
        quantized_kv_start=quantized_kv_start,
        kv_group_size=kv_group_size,
        kv_bits=kv_bits,
    )

    def _step(model, cache, y, n_predict=1):
        with mx.stream(generation_stream):
            logits = model(y[None], cache=cache)
            logits = logits[:, -n_predict:, :]

            quantize_cache_fn(cache)

            logprobs = logits - mx.logsumexp(logits, keepdims=True)
            y = sampler(logprobs).squeeze(0)
            return y, logprobs.squeeze(0)

    def _prefill(model, cache, y):
        while y.size > prefill_step_size:
            model(y[:prefill_step_size][None], cache=cache)
            quantize_cache_fn(cache)
            mx.eval([c.state for c in cache])
            y = y[prefill_step_size:]
            mx.metal.clear_cache()
        return y

    def _rewind_cache(num_draft, num_accept):
        cache.trim_prompt_cache(model_cache, num_draft - num_accept)
        cache.trim_prompt_cache(draft_cache, max(num_draft - num_accept - 1, 0))

    def _draft_generate(y, num_draft):
        if num_draft == 0:
            return mx.array([], mx.uint32)
        ys = []
        for _ in range(num_draft):
            y, _ = _step(draft_model, draft_cache, y)
            mx.async_eval(y)
            ys.append(y)
        return mx.concatenate(ys)

    with mx.stream(generation_stream):
        draft_y = _prefill(draft_model, draft_cache, y)
        y = _prefill(model, model_cache, y)

    ntoks = 0
    # Set these so the finally block doesn't raise
    num_draft = 0
    n = 0
    try:
        while True:
            num_draft = min(max_tokens - ntoks, num_draft_tokens)
            draft_tokens = _draft_generate(draft_y, num_draft)
            y = mx.concatenate([y, draft_tokens])

            tokens, logprobs = _step(model, model_cache, y, num_draft + 1)
            mx.eval(tokens, draft_tokens)
            draft_tokens = draft_tokens.tolist()
            tokens = tokens.tolist()
            n = 0
            while n < num_draft:
                tn, dtn, lpn = tokens[n], draft_tokens[n], logprobs[n]
                if tn != dtn:
                    break
                n += 1
                ntoks += 1
                yield tn, lpn
                if ntoks == max_tokens:
                    break
            if ntoks < max_tokens:
                ntoks += 1
                yield tokens[n], logprobs[n]

            if ntoks == max_tokens:
                break

            y = mx.array([tokens[n]], mx.uint32)
            draft_y = y

            # If we accepted all the draft tokens, include the last
            # draft token in the next draft step since it hasn't been
            # processed yet by the draft model
            if n == num_draft:
                draft_y = mx.concatenate(
                    [mx.array(draft_tokens[-1:], mx.uint32), draft_y]
                )

            _rewind_cache(num_draft, n)
    finally:
        _rewind_cache(num_draft, n)


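The new generator can also be driven directly, which is what `stream_generate` below does when a draft model is supplied. A minimal sketch, assuming `speculative_generate_step` is importable from `mlx_lm.utils` and using placeholder model paths; the two models must share a tokenizer, and the generator yields `(token, logprobs)` pairs until `max_tokens` is reached (it does not stop at EOS on its own, the caller handles that):

```python
# Illustrative sketch, not part of the diff. Model paths are placeholders.
import mlx.core as mx
from mlx_lm import load
from mlx_lm.utils import speculative_generate_step

model, tokenizer = load("<main-model>")
draft_model, _ = load("<draft-model>")

prompt = mx.array(tokenizer.encode("Explain speculative decoding in one sentence."))

tokens = []
for token, _logprobs in speculative_generate_step(
    prompt,
    model,
    draft_model,
    num_draft_tokens=2,
    max_tokens=64,
):
    tokens.append(token)
print(tokenizer.decode(tokens))
```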
def stream_generate(
    model: nn.Module,
    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
    prompt: Union[str, mx.array, List[int]],
    draft_model: Optional[nn.Module] = None,
    **kwargs,
) -> Generator[GenerationResponse, None, None]:
    """
@@ -341,7 +476,11 @@ def stream_generate(
    Args:
        model (nn.Module): The model to use for generation.
        tokenizer (PreTrainedTokenizer): The tokenizer.
        prompt (Union[str, mx.array, List[int]]): The input prompt string or integer tokens.
        prompt (Union[str, mx.array, List[int]]): The input prompt string or
          integer tokens.
        draft_model (Optional[nn.Module]): An optional draft model. If provided
          then speculative decoding is used. The draft model must use the same
          tokenizer as the main model. Default: ``None``.
        kwargs: The remaining options get passed to :func:`generate_step`.
          See :func:`generate_step` for more details.

@@ -363,10 +502,18 @@ def stream_generate(

    detokenizer = tokenizer.detokenizer

    if draft_model is None:
        kwargs.pop("num_draft_tokens", None)
        token_generator = generate_step(prompt, model, **kwargs)
    else:
        kwargs.pop("max_kv_size", None)
        token_generator = speculative_generate_step(
            prompt, model, draft_model, **kwargs
        )
    with wired_limit(model, [generation_stream]):
        detokenizer.reset()
        tic = time.perf_counter()
        for n, (token, logprobs) in enumerate(generate_step(prompt, model, **kwargs)):
        for n, (token, logprobs) in enumerate(token_generator):
            if n == 0:
                prompt_time = time.perf_counter() - tic
                prompt_tps = prompt.size / prompt_time

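Finally, a hedged sketch of the streaming entry point with the new `draft_model` keyword, whose signature is shown above. It assumes the yielded `GenerationResponse` objects expose the decoded `text` segment and uses placeholder model paths:

```python
# Illustrative sketch, not part of the diff. Model paths are placeholders.
from mlx_lm import load, stream_generate

model, tokenizer = load("<main-model>")
draft_model, _ = load("<draft-model>")

for response in stream_generate(
    model,
    tokenizer,
    "Explain KV cache quantization briefly.",
    draft_model=draft_model,  # falls back to standard decoding when None
    num_draft_tokens=3,
    max_tokens=128,
):
    print(response.text, end="", flush=True)
print()
```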