mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Add a speculative decoding generator (#1155)
* add a speculative decoding generator * fix * fixes * optional kwarg pop
This commit is contained in:
parent
5cae0a60e6
commit
93c5cfd781
@ -131,6 +131,18 @@ def setup_arg_parser():
|
||||
type=int,
|
||||
default=DEFAULT_QUANTIZED_KV_START,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--draft-model",
|
||||
type=str,
|
||||
help="A model to be used for speculative decoding.",
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-draft-tokens",
|
||||
type=int,
|
||||
help="Number of tokens to draft when using speculative decoding.",
|
||||
default=2,
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
@ -211,11 +223,16 @@ def main():
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
prompt = prompt[test_prompt.index("<query>") :]
|
||||
|
||||
prompt = tokenizer.encode(prompt, add_special_tokens=False)
|
||||
else:
|
||||
prompt = tokenizer.encode(prompt)
|
||||
|
||||
if args.draft_model is not None:
|
||||
draft_model, draft_tokenizer = load(args.draft_model)
|
||||
if draft_tokenizer.vocab_size != tokenizer.vocab_size:
|
||||
raise ValueError("Draft model tokenizer does not match model tokenizer.")
|
||||
else:
|
||||
draft_model = None
|
||||
sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
|
||||
response = generate(
|
||||
model,
|
||||
@ -229,6 +246,8 @@ def main():
|
||||
kv_bits=args.kv_bits,
|
||||
kv_group_size=args.kv_group_size,
|
||||
quantized_kv_start=args.quantized_kv_start,
|
||||
draft_model=draft_model,
|
||||
num_draft_tokens=args.num_draft_tokens,
|
||||
)
|
||||
if not args.verbose:
|
||||
print(response)
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
import contextlib
|
||||
import copy
|
||||
import functools
|
||||
import glob
|
||||
import importlib
|
||||
import json
|
||||
@ -207,12 +208,6 @@ def generate_step(
|
||||
kv_group_size: int = 64,
|
||||
quantized_kv_start: int = 0,
|
||||
prompt_progress_callback: Optional[Callable[int, int]] = None,
|
||||
temp: Optional[float] = None,
|
||||
repetition_penalty: Optional[float] = None,
|
||||
repetition_context_size: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
min_p: Optional[float] = None,
|
||||
min_tokens_to_keep: Optional[int] = None,
|
||||
) -> Generator[Tuple[mx.array, mx.array], None, None]:
|
||||
"""
|
||||
A generator producing token ids based on the given prompt from the model.
|
||||
@ -256,25 +251,17 @@ def generate_step(
|
||||
elif len(prompt_cache) != len(model.layers):
|
||||
raise ValueError("Wrong number of layers in the prompt cache.")
|
||||
|
||||
if temp is not None or top_p is not None or min_tokens_to_keep is not None:
|
||||
print(
|
||||
"[Warning] Specifying sampling arguments to ``generate_step`` is "
|
||||
"deprecated. Pass in a ``sampler`` instead."
|
||||
)
|
||||
if repetition_penalty is not None:
|
||||
print(
|
||||
"[Warning] Specifying ``repetition_penalty`` is deprecated. "
|
||||
"Pass in ``logits_processors`` instead."
|
||||
)
|
||||
|
||||
sampler = sampler or make_sampler(
|
||||
temp or 0.0, top_p or 0.0, min_p or 0.0, min_tokens_to_keep or 1
|
||||
)
|
||||
logits_processors = logits_processors or make_logits_processors(
|
||||
None, repetition_penalty, repetition_context_size or 20
|
||||
)
|
||||
prompt_progress_callback = prompt_progress_callback or (lambda *_: None)
|
||||
|
||||
quantize_cache_fn = functools.partial(
|
||||
maybe_quantize_kv_cache,
|
||||
quantized_kv_start=quantized_kv_start,
|
||||
kv_group_size=kv_group_size,
|
||||
kv_bits=kv_bits,
|
||||
)
|
||||
|
||||
sampler = sampler or (lambda x: mx.argmax(x, axis=-1))
|
||||
|
||||
def _step(y):
|
||||
with mx.stream(generation_stream):
|
||||
logits = model(y[None], cache=prompt_cache)
|
||||
@ -287,9 +274,7 @@ def generate_step(
|
||||
for processor in logits_processors:
|
||||
logits = processor(tokens, logits)
|
||||
|
||||
maybe_quantize_kv_cache(
|
||||
prompt_cache, quantized_kv_start, kv_group_size, kv_bits
|
||||
)
|
||||
quantize_cache_fn(prompt_cache)
|
||||
|
||||
logprobs = logits - mx.logsumexp(logits, keepdims=True)
|
||||
y = sampler(logprobs)
|
||||
@ -300,9 +285,7 @@ def generate_step(
|
||||
prompt_processed_tokens = 0
|
||||
while y.size > prefill_step_size:
|
||||
model(y[:prefill_step_size][None], cache=prompt_cache)
|
||||
maybe_quantize_kv_cache(
|
||||
prompt_cache, quantized_kv_start, kv_group_size, kv_bits
|
||||
)
|
||||
quantize_cache_fn(prompt_cache)
|
||||
mx.eval([c.state for c in prompt_cache])
|
||||
prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
|
||||
prompt_processed_tokens += prefill_step_size
|
||||
@ -329,10 +312,162 @@ def generate_step(
|
||||
n += 1
|
||||
|
||||
|
||||
def speculative_generate_step(
|
||||
prompt: mx.array,
|
||||
model: nn.Module,
|
||||
draft_model: nn.Module,
|
||||
*,
|
||||
num_draft_tokens=2,
|
||||
max_tokens: int = 256,
|
||||
sampler: Optional[Callable[mx.array, mx.array]] = None,
|
||||
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
|
||||
prompt_cache: Optional[Any] = None,
|
||||
prefill_step_size: int = 512,
|
||||
kv_bits: Optional[int] = None,
|
||||
kv_group_size: int = 64,
|
||||
quantized_kv_start: int = 0,
|
||||
) -> Generator[Tuple[mx.array, mx.array], None, None]:
|
||||
"""
|
||||
A generator producing token ids based on the given prompt from the model.
|
||||
|
||||
Args:
|
||||
prompt (mx.array): The input prompt.
|
||||
model (nn.Module): The model to use for generation.
|
||||
draft_model (nn.Module): The draft model for speculative decoding.
|
||||
num_draft_tokens (int, optional): The number of draft tokens for
|
||||
speculative decoding. Default: ``2``.
|
||||
max_tokens (int): The maximum number of tokens. Use``-1`` for an infinite
|
||||
generator. Default: ``256``.
|
||||
sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a
|
||||
token from a vector of log probabilities. Default: ``None``.
|
||||
logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
|
||||
A list of functions that take tokens and logits and return the processed
|
||||
logits. Default: ``None``.
|
||||
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
|
||||
provided, the cache will be updated in place. The cache must be trimmable.
|
||||
prefill_step_size (int): Step size for processing the prompt.
|
||||
kv_bits (int, optional): Number of bits to use for KV cache quantization.
|
||||
None implies no cache quantization. Default: ``None``.
|
||||
kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
|
||||
quantized_kv_start (int): Step to begin using a quantized KV cache.
|
||||
when ``kv_bits`` is non-None. Default: ``0``.
|
||||
|
||||
Yields:
|
||||
Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
|
||||
"""
|
||||
|
||||
y = prompt
|
||||
tokens = None
|
||||
|
||||
# Create the KV cache for generation
|
||||
if prompt_cache is None:
|
||||
model_cache = cache.make_prompt_cache(model)
|
||||
draft_cache = cache.make_prompt_cache(draft_model)
|
||||
elif len(prompt_cache) != (len(model.layers) + len(draft_model.layers)):
|
||||
raise ValueError("Wrong number of layers in the prompt cache.")
|
||||
else:
|
||||
model_cache = prompt_cache[: len(model.layers)]
|
||||
draft_cache = prompt_cache[len(model.layers) :]
|
||||
|
||||
sampler = sampler or (lambda x: mx.argmax(x, axis=-1))
|
||||
|
||||
quantize_cache_fn = functools.partial(
|
||||
maybe_quantize_kv_cache,
|
||||
quantized_kv_start=quantized_kv_start,
|
||||
kv_group_size=kv_group_size,
|
||||
kv_bits=kv_bits,
|
||||
)
|
||||
|
||||
def _step(model, cache, y, n_predict=1):
|
||||
with mx.stream(generation_stream):
|
||||
logits = model(y[None], cache=cache)
|
||||
logits = logits[:, -n_predict:, :]
|
||||
|
||||
quantize_cache_fn(cache)
|
||||
|
||||
logprobs = logits - mx.logsumexp(logits, keepdims=True)
|
||||
y = sampler(logprobs).squeeze(0)
|
||||
return y, logprobs.squeeze(0)
|
||||
|
||||
def _prefill(model, cache, y):
|
||||
while y.size > prefill_step_size:
|
||||
model(y[:prefill_step_size][None], cache=cache)
|
||||
quantize_cache_fn(cache)
|
||||
mx.eval([c.state for c in cache])
|
||||
y = y[prefill_step_size:]
|
||||
mx.metal.clear_cache()
|
||||
return y
|
||||
|
||||
def _rewind_cache(num_draft, num_accept):
|
||||
cache.trim_prompt_cache(model_cache, num_draft - num_accept)
|
||||
cache.trim_prompt_cache(draft_cache, max(num_draft - num_accept - 1, 0))
|
||||
|
||||
def _draft_generate(y, num_draft):
|
||||
if num_draft == 0:
|
||||
return mx.array([], mx.uint32)
|
||||
ys = []
|
||||
for _ in range(num_draft):
|
||||
y, _ = _step(draft_model, draft_cache, y)
|
||||
mx.async_eval(y)
|
||||
ys.append(y)
|
||||
return mx.concatenate(ys)
|
||||
|
||||
with mx.stream(generation_stream):
|
||||
draft_y = _prefill(draft_model, draft_cache, y)
|
||||
y = _prefill(model, model_cache, y)
|
||||
|
||||
ntoks = 0
|
||||
# Set these so the finally block doesn't raise
|
||||
num_draft = 0
|
||||
n = 0
|
||||
try:
|
||||
while True:
|
||||
num_draft = min(max_tokens - ntoks, num_draft_tokens)
|
||||
draft_tokens = _draft_generate(draft_y, num_draft)
|
||||
y = mx.concatenate([y, draft_tokens])
|
||||
|
||||
tokens, logprobs = _step(model, model_cache, y, num_draft + 1)
|
||||
mx.eval(tokens, draft_tokens)
|
||||
draft_tokens = draft_tokens.tolist()
|
||||
tokens = tokens.tolist()
|
||||
n = 0
|
||||
while n < num_draft:
|
||||
tn, dtn, lpn = tokens[n], draft_tokens[n], logprobs[n]
|
||||
if tn != dtn:
|
||||
break
|
||||
n += 1
|
||||
ntoks += 1
|
||||
yield tn, lpn
|
||||
if ntoks == max_tokens:
|
||||
break
|
||||
if ntoks < max_tokens:
|
||||
ntoks += 1
|
||||
yield tokens[n], logprobs[n]
|
||||
|
||||
if ntoks == max_tokens:
|
||||
break
|
||||
|
||||
y = mx.array([tokens[n]], mx.uint32)
|
||||
draft_y = y
|
||||
|
||||
# If we accpeted all the draft tokens, include the last
|
||||
# draft token in the next draft step since it hasn't been
|
||||
# processed yet by the draft model
|
||||
if n == num_draft:
|
||||
draft_y = mx.concatenate(
|
||||
[mx.array(draft_tokens[-1:], mx.uint32), draft_y]
|
||||
)
|
||||
|
||||
_rewind_cache(num_draft, n)
|
||||
finally:
|
||||
_rewind_cache(num_draft, n)
|
||||
|
||||
|
||||
def stream_generate(
|
||||
model: nn.Module,
|
||||
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
|
||||
prompt: Union[str, mx.array, List[int]],
|
||||
draft_model: Optional[nn.Module] = None,
|
||||
**kwargs,
|
||||
) -> Generator[GenerationResponse, None, None]:
|
||||
"""
|
||||
@ -341,7 +476,11 @@ def stream_generate(
|
||||
Args:
|
||||
model (nn.Module): The model to use for generation.
|
||||
tokenizer (PreTrainedTokenizer): The tokenizer.
|
||||
prompt (Union[str, mx.array, List[int]]): The input prompt string or integer tokens.
|
||||
prompt (Union[str, mx.array, List[int]]): The input prompt string or
|
||||
integer tokens.
|
||||
draft_model (Optional[nn.Module]): An optional draft model. If provided
|
||||
then speculative decoding is used. The draft model must use the same
|
||||
tokenizer as the main model. Default: ``None``.
|
||||
kwargs: The remaining options get passed to :func:`generate_step`.
|
||||
See :func:`generate_step` for more details.
|
||||
|
||||
@ -363,10 +502,18 @@ def stream_generate(
|
||||
|
||||
detokenizer = tokenizer.detokenizer
|
||||
|
||||
if draft_model is None:
|
||||
kwargs.pop("num_draft_tokens", None)
|
||||
token_generator = generate_step(prompt, model, **kwargs)
|
||||
else:
|
||||
kwargs.pop("max_kv_size", None)
|
||||
token_generator = speculative_generate_step(
|
||||
prompt, model, draft_model, **kwargs
|
||||
)
|
||||
with wired_limit(model, [generation_stream]):
|
||||
detokenizer.reset()
|
||||
tic = time.perf_counter()
|
||||
for n, (token, logprobs) in enumerate(generate_step(prompt, model, **kwargs)):
|
||||
for n, (token, logprobs) in enumerate(token_generator):
|
||||
if n == 0:
|
||||
prompt_time = time.perf_counter() - tic
|
||||
prompt_tps = prompt.size / prompt_time
|
||||
|
Loading…
Reference in New Issue
Block a user