mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
Fix logits processor bugs with spec dec (#1291)
* Fix logits processor bugs with spec dec * bump patch
This commit is contained in:
parent
85669451d0
commit
3d793ecf68
@ -1,3 +1,3 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
__version__ = "0.21.0"
|
__version__ = "0.21.5"
|
||||||
|
@ -409,8 +409,7 @@ def speculative_generate_step(
|
|||||||
for processor in logits_processors:
|
for processor in logits_processors:
|
||||||
logits = processor(tokens, logits)
|
logits = processor(tokens, logits)
|
||||||
|
|
||||||
logprobs = logits - mx.logsumexp(logits, keepdims=True)
|
logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
|
||||||
logprobs = logprobs.squeeze(0)
|
|
||||||
y = sampler(logprobs)
|
y = sampler(logprobs)
|
||||||
return y, logprobs
|
return y, logprobs
|
||||||
|
|
||||||
@ -429,16 +428,14 @@ def speculative_generate_step(
|
|||||||
prev_tokens = (
|
prev_tokens = (
|
||||||
mx.concat([prev_tokens, y]) if prev_tokens is not None else y
|
mx.concat([prev_tokens, y]) if prev_tokens is not None else y
|
||||||
)
|
)
|
||||||
y, logprobs = _process_and_sample(
|
y, logprobs = _process_and_sample(prev_tokens, logits[:, i, :])
|
||||||
prev_tokens, logits[:, i : i + 1, :]
|
|
||||||
)
|
|
||||||
out_y.append(y)
|
out_y.append(y)
|
||||||
out_logprobs.append(logprobs)
|
out_logprobs.append(logprobs)
|
||||||
return mx.concatenate(out_y, axis=0), mx.concatenate(
|
return mx.concatenate(out_y, axis=0), mx.concatenate(
|
||||||
out_logprobs, axis=0
|
out_logprobs, axis=0
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return _process_and_sample(None, logits)
|
return _process_and_sample(None, logits.squeeze(0))
|
||||||
|
|
||||||
def _prefill(model, cache, y):
|
def _prefill(model, cache, y):
|
||||||
while y.size > prefill_step_size:
|
while y.size > prefill_step_size:
|
||||||
@ -476,13 +473,9 @@ def speculative_generate_step(
|
|||||||
num_draft = min(max_tokens - ntoks, num_draft_tokens)
|
num_draft = min(max_tokens - ntoks, num_draft_tokens)
|
||||||
draft_tokens = _draft_generate(draft_y, num_draft)
|
draft_tokens = _draft_generate(draft_y, num_draft)
|
||||||
if prev_tokens is not None:
|
if prev_tokens is not None:
|
||||||
prev_tokens = prev_tokens[
|
prev_tokens = prev_tokens[: prev_tokens.size - y.size - num_draft + 1]
|
||||||
: prev_tokens.size - draft_y.size - num_draft + 1
|
|
||||||
]
|
|
||||||
y = mx.concatenate([y, draft_tokens])
|
y = mx.concatenate([y, draft_tokens])
|
||||||
|
|
||||||
tokens, logprobs = _step(model, model_cache, y, num_draft + 1)
|
tokens, logprobs = _step(model, model_cache, y, num_draft + 1)
|
||||||
|
|
||||||
mx.eval(tokens, draft_tokens)
|
mx.eval(tokens, draft_tokens)
|
||||||
draft_tokens = draft_tokens.tolist()
|
draft_tokens = draft_tokens.tolist()
|
||||||
tokens = tokens.tolist()
|
tokens = tokens.tolist()
|
||||||
@ -514,8 +507,8 @@ def speculative_generate_step(
|
|||||||
[mx.array(draft_tokens[-1:], mx.uint32), draft_y]
|
[mx.array(draft_tokens[-1:], mx.uint32), draft_y]
|
||||||
)
|
)
|
||||||
|
|
||||||
if prev_tokens is not None and n < num_draft:
|
if prev_tokens is not None:
|
||||||
prev_tokens = prev_tokens[: -(num_draft - n)]
|
prev_tokens = prev_tokens[: -max(num_draft - n, 1)]
|
||||||
_rewind_cache(num_draft, n)
|
_rewind_cache(num_draft, n)
|
||||||
finally:
|
finally:
|
||||||
_rewind_cache(num_draft, n)
|
_rewind_cache(num_draft, n)
|
||||||
|
Loading…
Reference in New Issue
Block a user