mirror of https://github.com/ml-explore/mlx-examples.git
add a speculative decoding generator
parent 5cae0a60e6
commit f01bc5fb17
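
This commit teaches the generation utilities to do speculative decoding: a small draft model proposes num_draft_tokens tokens per step, the main model verifies them in a single forward pass, the longest matching prefix is accepted, and both KV caches are rewound past anything rejected. A minimal usage sketch of the new path through stream_generate; the model repos below are placeholders, and any main/draft pair that shares a tokenizer should work:

# Sketch only: repo names are examples, not part of this commit.
from mlx_lm import load, stream_generate

model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")
draft_model, _ = load("mlx-community/Llama-3.2-1B-Instruct-4bit")

for response in stream_generate(
    model,
    tokenizer,
    "Write a haiku about speculative decoding.",
    draft_model=draft_model,  # enables the new speculative path
    max_tokens=128,
):
    print(response.text, end="", flush=True)
print()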
@@ -2,6 +2,7 @@
 
 import contextlib
 import copy
+import functools
 import glob
 import importlib
 import json
@@ -207,12 +208,6 @@ def generate_step(
     kv_group_size: int = 64,
     quantized_kv_start: int = 0,
     prompt_progress_callback: Optional[Callable[[int, int], None]] = None,
-    temp: Optional[float] = None,
-    repetition_penalty: Optional[float] = None,
-    repetition_context_size: Optional[int] = None,
-    top_p: Optional[float] = None,
-    min_p: Optional[float] = None,
-    min_tokens_to_keep: Optional[int] = None,
 ) -> Generator[Tuple[mx.array, mx.array], None, None]:
     """
     A generator producing token ids based on the given prompt from the model.
@@ -256,25 +251,15 @@ def generate_step(
     elif len(prompt_cache) != len(model.layers):
         raise ValueError("Wrong number of layers in the prompt cache.")
 
-    if temp is not None or top_p is not None or min_tokens_to_keep is not None:
-        print(
-            "[Warning] Specifying sampling arguments to ``generate_step`` is "
-            "deprecated. Pass in a ``sampler`` instead."
-        )
-    if repetition_penalty is not None:
-        print(
-            "[Warning] Specifying ``repetition_penalty`` is deprecated. "
-            "Pass in ``logits_processors`` instead."
-        )
-
-    sampler = sampler or make_sampler(
-        temp or 0.0, top_p or 0.0, min_p or 0.0, min_tokens_to_keep or 1
-    )
-    logits_processors = logits_processors or make_logits_processors(
-        None, repetition_penalty, repetition_context_size or 20
-    )
+    prompt_progress_callback = prompt_progress_callback or (lambda *_: None)
 
+    quantize_cache_fn = functools.partial(
+        maybe_quantize_kv_cache,
+        quantized_kv_start=quantized_kv_start,
+        kv_group_size=kv_group_size,
+        kv_bits=kv_bits,
+    )
 
     def _step(y):
         with mx.stream(generation_stream):
             logits = model(y[None], cache=prompt_cache)
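
With the deprecation shims removed, sampling is configured by passing a sampler and logits_processors into generate_step instead of temp, top_p, or repetition_penalty. A short sketch of the replacement pattern, assuming the make_sampler and make_logits_processors helpers live in mlx_lm.sample_utils (the same helpers the deleted code called) and that model and prompt_ids already exist:

# Sketch: build the sampling objects explicitly rather than passing
# the deprecated keyword arguments to generate_step.
from mlx_lm.sample_utils import make_logits_processors, make_sampler

sampler = make_sampler(temp=0.7, top_p=0.9)
logits_processors = make_logits_processors(repetition_penalty=1.1)

for token, logprobs in generate_step(
    prompt_ids,  # an mx.array of prompt token ids
    model,
    sampler=sampler,
    logits_processors=logits_processors,
    max_tokens=32,
):
    print(token)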
@@ -287,9 +272,7 @@ def generate_step(
                 for processor in logits_processors:
                     logits = processor(tokens, logits)
 
-            maybe_quantize_kv_cache(
-                prompt_cache, quantized_kv_start, kv_group_size, kv_bits
-            )
+            quantize_cache_fn(prompt_cache)
 
             logprobs = logits - mx.logsumexp(logits, keepdims=True)
             y = sampler(logprobs)
@@ -300,9 +283,7 @@ def generate_step(
         prompt_processed_tokens = 0
         while y.size > prefill_step_size:
             model(y[:prefill_step_size][None], cache=prompt_cache)
-            maybe_quantize_kv_cache(
-                prompt_cache, quantized_kv_start, kv_group_size, kv_bits
-            )
+            quantize_cache_fn(prompt_cache)
             mx.eval([c.state for c in prompt_cache])
             prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
             prompt_processed_tokens += prefill_step_size
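
Both the prefill and decode paths now call quantize_cache_fn(prompt_cache): functools.partial pre-binds the three quantization settings to maybe_quantize_kv_cache so the call sites no longer repeat them. A tiny standard-library illustration with a stand-in function:

import functools

def maybe_quantize(cache, quantized_kv_start, kv_group_size, kv_bits):
    # Stand-in for maybe_quantize_kv_cache; just shows what gets bound.
    print(len(cache), quantized_kv_start, kv_group_size, kv_bits)

quantize_fn = functools.partial(
    maybe_quantize, quantized_kv_start=0, kv_group_size=64, kv_bits=4
)
# Same as maybe_quantize([], quantized_kv_start=0, kv_group_size=64, kv_bits=4)
quantize_fn([])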
@@ -329,10 +310,143 @@ def generate_step(
         n += 1
 
 
+def speculative_generate_step(
+    prompt: mx.array,
+    model: nn.Module,
+    draft_model: nn.Module,
+    *,
+    num_draft_tokens=2,
+    max_tokens: int = 256,
+    prompt_cache: Optional[Any] = None,
+    prefill_step_size: int = 512,
+    kv_bits: Optional[int] = None,
+    kv_group_size: int = 64,
+    quantized_kv_start: int = 0,
+) -> Generator[Tuple[mx.array, mx.array], None, None]:
+    """
+    A generator producing token ids based on the given prompt from the model.
+
+    Args:
+        prompt (mx.array): The input prompt.
+        model (nn.Module): The model to use for generation.
+        draft_model (nn.Module): The draft model for speculative decoding.
+        num_draft_tokens (int, optional): The number of draft tokens for
+          speculative decoding. Default: ``2``.
+        max_tokens (int): The maximum number of tokens. Use ``-1`` for an infinite
+          generator. Default: ``256``.
+        prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
+          provided, the cache will be updated in place. The cache must be trimmable.
+        prefill_step_size (int): Step size for processing the prompt.
+        kv_bits (int, optional): Number of bits to use for KV cache quantization.
+          None implies no cache quantization. Default: ``None``.
+        kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
+        quantized_kv_start (int): Step to begin using a quantized KV cache
+          when ``kv_bits`` is non-None. Default: ``0``.
+
+    Yields:
+        Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
+    """
+
+    y = prompt
+    tokens = None
+
+    # Create the KV cache for generation
+    if prompt_cache is None:
+        model_cache = cache.make_prompt_cache(model)
+        draft_cache = cache.make_prompt_cache(draft_model)
+    elif len(prompt_cache) != (len(model.layers) + len(draft_model.layers)):
+        raise ValueError("Wrong number of layers in the prompt cache.")
+    else:
+        model_cache = prompt_cache[: len(model.layers)]
+        draft_cache = prompt_cache[len(model.layers) :]
+
+    # Bind the KV cache quantization options once for the helpers below
+    quantize_cache_fn = functools.partial(
+        maybe_quantize_kv_cache,
+        quantized_kv_start=quantized_kv_start,
+        kv_group_size=kv_group_size,
+        kv_bits=kv_bits,
+    )
+
+    def _step(model, cache, y, n_predict=1):
+        with mx.stream(generation_stream):
+            logits = model(y[None], cache=cache)
+            logits = logits[:, -n_predict:, :]
+
+            quantize_cache_fn(cache)
+
+            logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
+            y = mx.argmax(logprobs, axis=-1).squeeze(0)
+            return y, logprobs.squeeze(0)
+
+    def _prefill(model, cache, y):
+        while y.size > prefill_step_size:
+            model(y[:prefill_step_size][None], cache=cache)
+            quantize_cache_fn(cache)
+            mx.eval([c.state for c in cache])
+            y = y[prefill_step_size:]
+            mx.metal.clear_cache()
+        return y
+
+    def _rewind_cache(num_draft, num_accept):
+        cache.trim_prompt_cache(model_cache, num_draft - num_accept)
+        cache.trim_prompt_cache(draft_cache, max(num_draft - num_accept - 1, 0))
+
+    def _draft_generate(y, num_draft):
+        if num_draft == 0:
+            return mx.array([], mx.uint32)
+        ys = []
+        for _ in range(num_draft):
+            y, _ = _step(draft_model, draft_cache, y)
+            mx.async_eval(y)
+            ys.append(y)
+        return mx.concatenate(ys)
+
+    with mx.stream(generation_stream):
+        draft_y = _prefill(draft_model, draft_cache, y)
+        y = _prefill(model, model_cache, y)
+
+    ntoks = 0
+    try:
+        while True:
+            num_draft = min(max_tokens - ntoks, num_draft_tokens)
+            draft_tokens = _draft_generate(draft_y, num_draft)
+            y = mx.concatenate([y, draft_tokens])
+
+            tokens, logprobs = _step(model, model_cache, y, num_draft + 1)
+            mx.eval(tokens, draft_tokens)
+            draft_tokens = draft_tokens.tolist()
+            tokens = tokens.tolist()
+            n = 0
+            while n < num_draft:
+                tn, dtn, lpn = tokens[n], draft_tokens[n], logprobs[n]
+                if tn != dtn:
+                    break
+                n += 1
+                ntoks += 1
+                yield tn, lpn
+                if ntoks == max_tokens:
+                    break
+            if ntoks < max_tokens:
+                ntoks += 1
+                yield tokens[n], logprobs[n]
+
+            if ntoks == max_tokens:
+                break
+
+            y = mx.array([tokens[n]], mx.uint32)
+            draft_y = y
+
+            # If we accepted all the draft tokens, include the last
+            # draft token in the next draft step since it hasn't been
+            # processed yet by the draft model
+            if n == num_draft:
+                draft_y = mx.concatenate(
+                    [mx.array(draft_tokens[-1:], mx.uint32), draft_y]
+                )
+
+            _rewind_cache(num_draft, n)
+    finally:
+        _rewind_cache(num_draft, n)
+
+
 def stream_generate(
     model: nn.Module,
     tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
     prompt: Union[str, mx.array, List[int]],
+    draft_model: Optional[nn.Module] = None,
     **kwargs,
 ) -> Generator[GenerationResponse, None, None]:
     """
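
The verification loop above accepts the longest prefix of draft tokens that matches the main model's own greedy predictions, then always emits one token from the main model (either the correction at the first mismatch or the bonus token after a fully accepted draft) before rewinding the caches. A framework-free sketch of that acceptance rule, with plain lists standing in for the arrays:

def accept_draft(target_tokens, draft_tokens):
    # target_tokens has len(draft_tokens) + 1 entries: the main model's
    # prediction at every draft position plus one bonus position.
    accepted = []
    n = 0
    # Accept draft tokens while they agree with the main model.
    while n < len(draft_tokens) and target_tokens[n] == draft_tokens[n]:
        accepted.append(draft_tokens[n])
        n += 1
    # Emit one token from the main model: the correction at the first
    # mismatch, or the bonus token after a fully accepted draft.
    accepted.append(target_tokens[n])
    # Caches are rewound by the number of rejected draft positions.
    num_rejected = len(draft_tokens) - n
    return accepted, num_rejected

# Draft proposed [5, 9, 2]; the main model predicted [5, 9, 7, 4].
print(accept_draft([5, 9, 7, 4], [5, 9, 2]))  # ([5, 9, 7], 1)

In the generator itself the draft cache is trimmed by one position less than the main cache, since the last drafted token was never fed back through the draft model, and the loop additionally stops once max_tokens tokens have been yielded.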
@@ -341,7 +455,11 @@ def stream_generate(
     Args:
         model (nn.Module): The model to use for generation.
         tokenizer (PreTrainedTokenizer): The tokenizer.
-        prompt (Union[str, mx.array, List[int]]): The input prompt string or integer tokens.
+        prompt (Union[str, mx.array, List[int]]): The input prompt string or
+          integer tokens.
+        draft_model (Optional[nn.Module]): An optional draft model. If provided
+          then speculative decoding is used. The draft model must use the same
+          tokenizer as the main model. Default: ``None``.
         kwargs: The remaining options get passed to :func:`generate_step`.
           See :func:`generate_step` for more details.
 
@@ -363,10 +481,16 @@ def stream_generate(
 
     detokenizer = tokenizer.detokenizer
 
+    if draft_model is None:
+        token_generator = generate_step(prompt, model, **kwargs)
+    else:
+        token_generator = speculative_generate_step(
+            prompt, model, draft_model, **kwargs
+        )
     with wired_limit(model, [generation_stream]):
         detokenizer.reset()
         tic = time.perf_counter()
-        for n, (token, logprobs) in enumerate(generate_step(prompt, model, **kwargs)):
+        for n, (token, logprobs) in enumerate(token_generator):
             if n == 0:
                 prompt_time = time.perf_counter() - tic
                 prompt_tps = prompt.size / prompt_time
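
The generator can also be driven directly when the raw (token, logprobs) stream is wanted instead of GenerationResponse objects. A sketch, assuming model, draft_model, and tokenizer were loaded elsewhere (for example with mlx_lm.load) and that the function lives next to generate_step in mlx_lm.utils:

import mlx.core as mx
from mlx_lm.utils import speculative_generate_step

# The low-level generator expects the prompt as an array of token ids.
prompt_ids = mx.array(tokenizer.encode("Explain KV caches in one sentence."))

generated = []
for token, logprobs in speculative_generate_step(
    prompt_ids,
    model,
    draft_model,
    num_draft_tokens=3,
    max_tokens=64,
):
    if token == tokenizer.eos_token_id:
        break
    generated.append(token)

print(tokenizer.decode(generated))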