add logits processor to spec gen (#1260)

Awni Hannun 2025-02-13 19:19:53 -08:00 committed by GitHub
parent ec30dc3538
commit 7b07b14e67


@@ -382,8 +382,8 @@ def speculative_generate_step(
         and a bool indicating if the token was generated by the draft model
     """
-    y = prompt
-    tokens = None
+    y = prompt.astype(mx.uint32)
+    prev_tokens = None
 
     # Create the KV cache for generation
     if prompt_cache is None:
@@ -404,17 +404,41 @@ def speculative_generate_step(
             kv_bits=kv_bits,
         )
 
+    def _process_and_sample(tokens, logits):
+        if logits_processors:
+            for processor in logits_processors:
+                logits = processor(tokens, logits)
+
+        logprobs = logits - mx.logsumexp(logits, keepdims=True)
+        logprobs = logprobs.squeeze(0)
+        y = sampler(logprobs)
+        return y, logprobs
+
     def _step(model, cache, y, n_predict=1):
         with mx.stream(generation_stream):
             logits = model(y[None], cache=cache)
             logits = logits[:, -n_predict:, :]
             quantize_cache_fn(cache)
-
-            logprobs = logits - mx.logsumexp(logits, keepdims=True)
-            logprobs = logprobs.squeeze(0)
-            y = sampler(logprobs)
-            return y, logprobs
+            if logits_processors:
+                nonlocal prev_tokens
+                out_y, out_logprobs = [], []
+                if n_predict > 1:
+                    y = y[: -(n_predict - 1)]
+                for i in range(n_predict):
+                    prev_tokens = (
+                        mx.concat([prev_tokens, y]) if prev_tokens is not None else y
+                    )
+                    y, logprobs = _process_and_sample(
+                        prev_tokens, logits[:, i : i + 1, :]
+                    )
+                    out_y.append(y)
+                    out_logprobs.append(logprobs)
+                return mx.concatenate(out_y, axis=0), mx.concatenate(
+                    out_logprobs, axis=0
+                )
+            else:
+                return _process_and_sample(None, logits)
 
     def _prefill(model, cache, y):
         while y.size > prefill_step_size:
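For reference, each entry in `logits_processors` is a callable with the signature `processor(tokens, logits) -> logits`, where `tokens` is the history of previously sampled tokens (threaded through `prev_tokens` above) and `logits` are the raw scores for the position being sampled. A minimal sketch of a conforming processor; the repetition-penalty logic here is illustrative, not part of this commit:

```python
import mlx.core as mx

def make_repetition_penalty(penalty: float = 1.1):
    # Illustrative processor matching the (tokens, logits) -> logits
    # interface that _process_and_sample calls. Tokens already present
    # in the history have their logits scaled away from selection.
    def processor(tokens: mx.array, logits: mx.array) -> mx.array:
        if tokens is None or tokens.size == 0:
            return logits
        seen = logits[..., tokens]
        seen = mx.where(seen < 0, seen * penalty, seen / penalty)
        logits[..., tokens] = seen
        return logits

    return processor
```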
@@ -451,9 +475,14 @@ def speculative_generate_step(
         while True:
             num_draft = min(max_tokens - ntoks, num_draft_tokens)
             draft_tokens = _draft_generate(draft_y, num_draft)
+            if prev_tokens is not None:
+                prev_tokens = prev_tokens[
+                    : prev_tokens.size - draft_y.size - num_draft + 1
+                ]
             y = mx.concatenate([y, draft_tokens])
 
             tokens, logprobs = _step(model, model_cache, y, num_draft + 1)
             mx.eval(tokens, draft_tokens)
             draft_tokens = draft_tokens.tolist()
             tokens = tokens.tolist()
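The trim above is the bookkeeping that keeps `prev_tokens` consistent: during `_draft_generate`, each draft-model `_step` appends its input to `prev_tokens`, so the history grows by `draft_y.size + num_draft - 1` tokens that the verification `_step` is about to re-append itself. A worked trace of the arithmetic, with hypothetical sizes:

```python
# Hypothetical trace: 10 tokens of history before drafting,
# draft_y.size == 1, num_draft == 3.
history = 10

# The draft pass appends draft_y once, then each fed-back draft token
# (num_draft - 1 of them):
after_draft = history + 1 + (3 - 1)  # 13

# The slice keeps prev_tokens.size - draft_y.size - num_draft + 1
# entries, restoring the pre-draft history:
restored = after_draft - 1 - 3 + 1  # 10
print(restored == history)  # True
```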
@@ -485,6 +514,8 @@ def speculative_generate_step(
                     [mx.array(draft_tokens[-1:], mx.uint32), draft_y]
                 )
 
+            if prev_tokens is not None and n < num_draft:
+                prev_tokens = prev_tokens[: -(num_draft - n)]
             _rewind_cache(num_draft, n)
     finally:
         _rewind_cache(num_draft, n)
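Symmetrically, this last trim drops the `num_draft - n` rejected draft tokens from the history so it matches what was actually accepted. End to end, the change means speculative decoding can apply the same logits processors as regular generation. A usage sketch, assuming the `mlx_lm` API of this period (`load`, `stream_generate`, and the helpers in `mlx_lm.sample_utils`); the model names are placeholders:

```python
from mlx_lm import load, stream_generate
from mlx_lm.sample_utils import make_logits_processors, make_sampler

# Placeholder models: any pair where the draft model is a smaller model
# sharing the target model's tokenizer.
model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")
draft_model, _ = load("mlx-community/Llama-3.2-1B-Instruct-4bit")

prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Write a haiku about the ocean."}],
    add_generation_prompt=True,
)

for response in stream_generate(
    model,
    tokenizer,
    prompt,
    max_tokens=128,
    draft_model=draft_model,  # routes through speculative_generate_step
    num_draft_tokens=3,
    sampler=make_sampler(temp=0.7),
    logits_processors=make_logits_processors(repetition_penalty=1.1),
):
    print(response.text, end="", flush=True)
```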