This commit is contained in:
Awni Hannun 2024-12-13 20:21:34 -08:00
parent e34ecb79b4
commit 8c0b4ee7f3
2 changed files with 42 additions and 2 deletions

View File

@@ -131,6 +131,18 @@ def setup_arg_parser():
type=int,
default=DEFAULT_QUANTIZED_KV_START,
)
parser.add_argument(
"--draft-model",
type=str,
help="A model to be used for speculative decoding.",
default=None,
)
parser.add_argument(
"--num-draft-tokens",
type=int,
help="Number of tokens to draft when using speculative decoding.",
default=2,
)
return parser
@@ -211,11 +223,16 @@ def main():
add_generation_prompt=True,
)
prompt = prompt[test_prompt.index("<query>") :]
prompt = tokenizer.encode(prompt, add_special_tokens=False)
else:
prompt = tokenizer.encode(prompt)
if args.draft_model is not None:
draft_model, draft_tokenizer = load(args.draft_model)
if draft_tokenizer.vocab_size != tokenizer.vocab_size:
raise ValueError("Draft model tokenizer does not match model tokenizer.")
else:
draft_model = None
sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
response = generate(
model,
@@ -229,6 +246,8 @@ def main():
kv_bits=args.kv_bits,
kv_group_size=args.kv_group_size,
quantized_kv_start=args.quantized_kv_start,
draft_model=draft_model,
num_draft_tokens=args.num_draft_tokens,
)
if not args.verbose:
print(response)

View File

@@ -319,6 +319,8 @@ def speculative_generate_step(
*,
num_draft_tokens=2,
max_tokens: int = 256,
sampler: Optional[Callable[mx.array, mx.array]] = None,
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
prompt_cache: Optional[Any] = None,
prefill_step_size: int = 512,
kv_bits: Optional[int] = None,
@@ -336,6 +338,11 @@ def speculative_generate_step(
speculative decoding. Default: ``2``.
max_tokens (int): The maximum number of tokens. Use``-1`` for an infinite
generator. Default: ``256``.
sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a
token from a vector of log probabilities. Default: ``None``.
logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
A list of functions that take tokens and logits and return the processed
logits. Default: ``None``.
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
provided, the cache will be updated in place. The cache must be trimmable.
prefill_step_size (int): Step size for processing the prompt.
@@ -362,6 +369,15 @@ def speculative_generate_step(
model_cache = prompt_cache[: len(model.layers)]
draft_cache = prompt_cache[len(model.layers) :]
sampler = sampler or (lambda x: mx.argmax(x, axis=-1))
quantize_cache_fn = functools.partial(
maybe_quantize_kv_cache,
quantized_kv_start=quantized_kv_start,
kv_group_size=kv_group_size,
kv_bits=kv_bits,
)
def _step(model, cache, y, n_predict=1):
with mx.stream(generation_stream):
logits = model(y[None], cache=cache)
@@ -370,7 +386,7 @@ def speculative_generate_step(
quantize_cache_fn(cache)
logprobs = logits - mx.logsumexp(logits, keepdims=True)
y = mx.argmax(logprobs, axis=-1).squeeze(0)
y = sampler(logprobs).squeeze(0)
return y, logprobs.squeeze(0)
def _prefill(model, cache, y):
@@ -401,6 +417,9 @@ def speculative_generate_step(
y = _prefill(model, model_cache, y)
ntoks = 0
# Set these so the finally block doesn't raise
num_draft = 0
n = 0
try:
while True:
num_draft = min(max_tokens - ntoks, num_draft_tokens)
@@ -484,8 +503,10 @@ def stream_generate(
detokenizer = tokenizer.detokenizer
if draft_model is None:
kwargs.pop("num_draft_tokens")
token_generator = generate_step(prompt, model, **kwargs)
else:
kwargs.pop("max_kv_size")
token_generator = speculative_generate_step(
prompt, model, draft_model, **kwargs
)