add a speculative decoding generator

Awni Hannun 2024-12-13 17:57:44 -08:00
parent 5cae0a60e6
commit f01bc5fb17

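A usage sketch of the feature this commit adds, driven through the new ``draft_model`` argument on ``stream_generate`` (the checkpoint names are placeholders, not part of the commit; ``load`` and ``stream_generate`` are assumed to be the package's top-level exports):

```python
from mlx_lm import load, stream_generate

# Placeholder checkpoints; the draft model must use the main model's tokenizer.
model, tokenizer = load("<main-model-repo>")
draft_model, _ = load("<small-draft-model-repo>")

for response in stream_generate(
    model, tokenizer, "Write a story about Einstein", draft_model=draft_model
):
    print(response.text, end="", flush=True)
```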

@@ -2,6 +2,7 @@
 import contextlib
 import copy
+import functools
 import glob
 import importlib
 import json
@@ -207,12 +208,6 @@ def generate_step(
     kv_group_size: int = 64,
     quantized_kv_start: int = 0,
     prompt_progress_callback: Optional[Callable[int, int]] = None,
-    temp: Optional[float] = None,
-    repetition_penalty: Optional[float] = None,
-    repetition_context_size: Optional[int] = None,
-    top_p: Optional[float] = None,
-    min_p: Optional[float] = None,
-    min_tokens_to_keep: Optional[int] = None,
 ) -> Generator[Tuple[mx.array, mx.array], None, None]:
     """
     A generator producing token ids based on the given prompt from the model.
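The six sampling keywords removed above now live behind the existing ``sampler`` and ``logits_processors`` arguments. A minimal migration sketch for callers; the ``make_*`` signatures follow the fallback code removed in the next hunk, and the module paths are assumptions about this repo's layout:

```python
from mlx_lm.sample_utils import make_logits_processors, make_sampler
from mlx_lm.utils import generate_step

# `model` (nn.Module) and `prompt` (mx.array) are assumed to be set up elsewhere.
sampler = make_sampler(temp=0.7, top_p=0.9)
logits_processors = make_logits_processors(repetition_penalty=1.1)

for token, logprobs in generate_step(
    prompt, model, sampler=sampler, logits_processors=logits_processors
):
    ...
```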
@@ -256,25 +251,15 @@ def generate_step(
     elif len(prompt_cache) != len(model.layers):
         raise ValueError("Wrong number of layers in the prompt cache.")
 
-    if temp is not None or top_p is not None or min_tokens_to_keep is not None:
-        print(
-            "[Warning] Specifying sampling arguments to ``generate_step`` is "
-            "deprecated. Pass in a ``sampler`` instead."
-        )
-    if repetition_penalty is not None:
-        print(
-            "[Warning] Specifying ``repetition_penalty`` is deprecated. "
-            "Pass in ``logits_processors`` instead."
-        )
-    sampler = sampler or make_sampler(
-        temp or 0.0, top_p or 0.0, min_p or 0.0, min_tokens_to_keep or 1
-    )
-    logits_processors = logits_processors or make_logits_processors(
-        None, repetition_penalty, repetition_context_size or 20
-    )
     prompt_progress_callback = prompt_progress_callback or (lambda *_: None)
 
+    quantize_cache_fn = functools.partial(
+        maybe_quantize_kv_cache,
+        quantized_kv_start=quantized_kv_start,
+        kv_group_size=kv_group_size,
+        kv_bits=kv_bits,
+    )
+
     def _step(y):
         with mx.stream(generation_stream):
             logits = model(y[None], cache=prompt_cache)
@@ -287,9 +272,7 @@ def generate_step(
             for processor in logits_processors:
                 logits = processor(tokens, logits)
 
-            maybe_quantize_kv_cache(
-                prompt_cache, quantized_kv_start, kv_group_size, kv_bits
-            )
+            quantize_cache_fn(prompt_cache)
 
             logprobs = logits - mx.logsumexp(logits, keepdims=True)
             y = sampler(logprobs)
@@ -300,9 +283,7 @@ def generate_step(
         prompt_processed_tokens = 0
         while y.size > prefill_step_size:
             model(y[:prefill_step_size][None], cache=prompt_cache)
-            maybe_quantize_kv_cache(
-                prompt_cache, quantized_kv_start, kv_group_size, kv_bits
-            )
+            quantize_cache_fn(prompt_cache)
             mx.eval([c.state for c in prompt_cache])
             prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
             prompt_processed_tokens += prefill_step_size
@@ -329,10 +310,143 @@ def generate_step(
         n += 1
 
 
+def speculative_generate_step(
+    prompt: mx.array,
+    model: nn.Module,
+    draft_model: nn.Module,
+    *,
+    num_draft_tokens=2,
+    max_tokens: int = 256,
+    prompt_cache: Optional[Any] = None,
+    prefill_step_size: int = 512,
+    kv_bits: Optional[int] = None,
+    kv_group_size: int = 64,
+    quantized_kv_start: int = 0,
+) -> Generator[Tuple[mx.array, mx.array], None, None]:
+    """
+    A generator producing token ids based on the given prompt from the model.
+
+    Args:
+        prompt (mx.array): The input prompt.
+        model (nn.Module): The model to use for generation.
+        draft_model (nn.Module): The draft model for speculative decoding.
+        num_draft_tokens (int, optional): The number of draft tokens for
+          speculative decoding. Default: ``2``.
+        max_tokens (int): The maximum number of tokens. Use ``-1`` for an
+          infinite generator. Default: ``256``.
+        prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
+          provided, the cache will be updated in place. The cache must be trimmable.
+        prefill_step_size (int): Step size for processing the prompt.
+        kv_bits (int, optional): Number of bits to use for KV cache quantization.
+          None implies no cache quantization. Default: ``None``.
+        kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
+        quantized_kv_start (int): Step to begin using a quantized KV cache
+          when ``kv_bits`` is non-None. Default: ``0``.
+
+    Yields:
+        Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
+    """
+
+    y = prompt
+    tokens = None
+
+    # Create the KV cache for generation
+    if prompt_cache is None:
+        model_cache = cache.make_prompt_cache(model)
+        draft_cache = cache.make_prompt_cache(draft_model)
+    elif len(prompt_cache) != (len(model.layers) + len(draft_model.layers)):
+        raise ValueError("Wrong number of layers in the prompt cache.")
+    else:
+        model_cache = prompt_cache[: len(model.layers)]
+        draft_cache = prompt_cache[len(model.layers) :]
+
+    # Quantize the KV caches in place once they are long enough
+    quantize_cache_fn = functools.partial(
+        maybe_quantize_kv_cache,
+        quantized_kv_start=quantized_kv_start,
+        kv_group_size=kv_group_size,
+        kv_bits=kv_bits,
+    )
+
+    def _step(model, cache, y, n_predict=1):
+        with mx.stream(generation_stream):
+            logits = model(y[None], cache=cache)
+            logits = logits[:, -n_predict:, :]
+
+            quantize_cache_fn(cache)
+
+            # Normalize each predicted position over the vocabulary axis
+            logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
+            y = mx.argmax(logprobs, axis=-1).squeeze(0)
+            return y, logprobs.squeeze(0)
+
+    def _prefill(model, cache, y):
+        while y.size > prefill_step_size:
+            model(y[:prefill_step_size][None], cache=cache)
+            quantize_cache_fn(cache)
+            mx.eval([c.state for c in cache])
+            y = y[prefill_step_size:]
+            mx.metal.clear_cache()
+        return y
+
+    def _rewind_cache(num_draft, num_accept):
+        cache.trim_prompt_cache(model_cache, num_draft - num_accept)
+        cache.trim_prompt_cache(draft_cache, max(num_draft - num_accept - 1, 0))
+
+    def _draft_generate(y, num_draft):
+        if num_draft == 0:
+            return mx.array([], mx.uint32)
+        ys = []
+        for _ in range(num_draft):
+            y, _ = _step(draft_model, draft_cache, y)
+            mx.async_eval(y)
+            ys.append(y)
+        return mx.concatenate(ys)
+
+    with mx.stream(generation_stream):
+        draft_y = _prefill(draft_model, draft_cache, y)
+        y = _prefill(model, model_cache, y)
+
+    ntoks = 0
+    try:
+        while True:
+            num_draft = min(max_tokens - ntoks, num_draft_tokens)
+            draft_tokens = _draft_generate(draft_y, num_draft)
+            y = mx.concatenate([y, draft_tokens])
+
+            tokens, logprobs = _step(model, model_cache, y, num_draft + 1)
+            mx.eval(tokens, draft_tokens)
+            draft_tokens = draft_tokens.tolist()
+            tokens = tokens.tolist()
+
+            # Accept draft tokens while they match the main model's greedy
+            # choice, then emit the main model's own next token
+            n = 0
+            while n < num_draft:
+                tn, dtn, lpn = tokens[n], draft_tokens[n], logprobs[n]
+                if tn != dtn:
+                    break
+                n += 1
+                ntoks += 1
+                yield tn, lpn
+                if ntoks == max_tokens:
+                    break
+
+            if ntoks < max_tokens:
+                ntoks += 1
+                yield tokens[n], logprobs[n]
+
+            if ntoks == max_tokens:
+                break
+
+            y = mx.array([tokens[n]], mx.uint32)
+            draft_y = y
+
+            # If we accepted all the draft tokens, include the last
+            # draft token in the next draft step since it hasn't been
+            # processed yet by the draft model
+            if n == num_draft:
+                draft_y = mx.concatenate(
+                    [mx.array(draft_tokens[-1:], mx.uint32), draft_y]
+                )
+
+            _rewind_cache(num_draft, n)
+    finally:
+        _rewind_cache(num_draft, n)
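For callers that bypass ``stream_generate``, the new generator can be driven directly. A minimal sketch; the loaded ``model``, ``draft_model``, and ``tokenizer`` are assumed, and each yielded token is a plain Python int per the ``tolist()`` calls above:

```python
import mlx.core as mx

prompt = mx.array(tokenizer.encode("Hello"))
for token, logprobs in speculative_generate_step(
    prompt, model, draft_model, num_draft_tokens=3, max_tokens=64
):
    print(tokenizer.decode([token]), end="", flush=True)
```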
 def stream_generate(
     model: nn.Module,
     tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
     prompt: Union[str, mx.array, List[int]],
+    draft_model: Optional[nn.Module] = None,
     **kwargs,
 ) -> Generator[GenerationResponse, None, None]:
     """
@@ -341,7 +455,11 @@ def stream_generate(
     Args:
         model (nn.Module): The model to use for generation.
         tokenizer (PreTrainedTokenizer): The tokenizer.
-        prompt (Union[str, mx.array, List[int]]): The input prompt string or integer tokens.
+        prompt (Union[str, mx.array, List[int]]): The input prompt string or
+          integer tokens.
+        draft_model (Optional[nn.Module]): An optional draft model. If provided
+          then speculative decoding is used. The draft model must use the same
+          tokenizer as the main model. Default: ``None``.
         kwargs: The remaining options get passed to :func:`generate_step`.
           See :func:`generate_step` for more details.
@@ -363,10 +481,16 @@ def stream_generate(
     detokenizer = tokenizer.detokenizer
 
+    if draft_model is None:
+        token_generator = generate_step(prompt, model, **kwargs)
+    else:
+        token_generator = speculative_generate_step(
+            prompt, model, draft_model, **kwargs
+        )
     with wired_limit(model, [generation_stream]):
         detokenizer.reset()
         tic = time.perf_counter()
-        for n, (token, logprobs) in enumerate(generate_step(prompt, model, **kwargs)):
+        for n, (token, logprobs) in enumerate(token_generator):
             if n == 0:
                 prompt_time = time.perf_counter() - tic
                 prompt_tps = prompt.size / prompt_time
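The acceptance rule in ``speculative_generate_step`` is easy to see in isolation: draft tokens are kept only while they match the main model's greedy verification, and the main model's token at the first mismatch (or one past a fully accepted run) is emitted next. A toy, MLX-free sketch with made-up token ids:

```python
draft_tokens = [12, 7, 99]     # proposed by the draft model
tokens = [12, 7, 42, 8]        # main model's choice at each position, plus one extra

n = 0
while n < len(draft_tokens) and tokens[n] == draft_tokens[n]:
    n += 1                     # accept matching draft tokens
emitted = tokens[: n + 1]      # accepted run plus the main model's next token
print(emitted)                 # [12, 7, 42]: two drafts accepted, mismatch corrected
```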