This commit is contained in:
Awni Hannun 2024-12-13 20:21:34 -08:00
parent e34ecb79b4
commit 8c0b4ee7f3
2 changed files with 42 additions and 2 deletions

View File

@ -131,6 +131,18 @@ def setup_arg_parser():
type=int,
default=DEFAULT_QUANTIZED_KV_START,
)
parser.add_argument(
"--draft-model",
type=str,
help="A model to be used for speculative decoding.",
default=None,
)
parser.add_argument(
"--num-draft-tokens",
type=int,
help="Number of tokens to draft when using speculative decoding.",
default=2,
)
return parser
@ -211,11 +223,16 @@ def main():
add_generation_prompt=True,
)
prompt = prompt[test_prompt.index("<query>") :]
prompt = tokenizer.encode(prompt, add_special_tokens=False)
else:
prompt = tokenizer.encode(prompt)
if args.draft_model is not None:
draft_model, draft_tokenizer = load(args.draft_model)
if draft_tokenizer.vocab_size != tokenizer.vocab_size:
raise ValueError("Draft model tokenizer does not match model tokenizer.")
else:
draft_model = None
sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
response = generate(
model,
@ -229,6 +246,8 @@ def main():
kv_bits=args.kv_bits,
kv_group_size=args.kv_group_size,
quantized_kv_start=args.quantized_kv_start,
draft_model=draft_model,
num_draft_tokens=args.num_draft_tokens,
)
if not args.verbose:
print(response)

View File

@ -319,6 +319,8 @@ def speculative_generate_step(
*,
num_draft_tokens=2,
max_tokens: int = 256,
sampler: Optional[Callable[mx.array, mx.array]] = None,
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
prompt_cache: Optional[Any] = None,
prefill_step_size: int = 512,
kv_bits: Optional[int] = None,
@ -336,6 +338,11 @@ def speculative_generate_step(
speculative decoding. Default: ``2``.
max_tokens (int): The maximum number of tokens. Use``-1`` for an infinite
generator. Default: ``256``.
sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a
token from a vector of log probabilities. Default: ``None``.
logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
A list of functions that take tokens and logits and return the processed
logits. Default: ``None``.
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
provided, the cache will be updated in place. The cache must be trimmable.
prefill_step_size (int): Step size for processing the prompt.
@ -362,6 +369,15 @@ def speculative_generate_step(
model_cache = prompt_cache[: len(model.layers)]
draft_cache = prompt_cache[len(model.layers) :]
sampler = sampler or (lambda x: mx.argmax(x, axis=-1))
quantize_cache_fn = functools.partial(
maybe_quantize_kv_cache,
quantized_kv_start=quantized_kv_start,
kv_group_size=kv_group_size,
kv_bits=kv_bits,
)
def _step(model, cache, y, n_predict=1):
with mx.stream(generation_stream):
logits = model(y[None], cache=cache)
@ -370,7 +386,7 @@ def speculative_generate_step(
quantize_cache_fn(cache)
logprobs = logits - mx.logsumexp(logits, keepdims=True)
y = mx.argmax(logprobs, axis=-1).squeeze(0)
y = sampler(logprobs).squeeze(0)
return y, logprobs.squeeze(0)
def _prefill(model, cache, y):
@ -401,6 +417,9 @@ def speculative_generate_step(
y = _prefill(model, model_cache, y)
ntoks = 0
# Set these so the finally block doesn't raise
num_draft = 0
n = 0
try:
while True:
num_draft = min(max_tokens - ntoks, num_draft_tokens)
@ -484,8 +503,10 @@ def stream_generate(
detokenizer = tokenizer.detokenizer
if draft_model is None:
kwargs.pop("num_draft_tokens")
token_generator = generate_step(prompt, model, **kwargs)
else:
kwargs.pop("max_kv_size")
token_generator = speculative_generate_step(
prompt, model, draft_model, **kwargs
)