From 3d793ecf6887512fd81f0f0c6bd156c06a6eaf88 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 20 Feb 2025 15:55:55 -0800 Subject: [PATCH] Fix logits processor bugs with spec dec (#1291) * Fix logits processor bugs with spec dec * bump patch --- llms/mlx_lm/_version.py | 2 +- llms/mlx_lm/utils.py | 19 ++++++------------- 2 files changed, 7 insertions(+), 14 deletions(-) diff --git a/llms/mlx_lm/_version.py b/llms/mlx_lm/_version.py index b2f98e6f..89e6cd00 100644 --- a/llms/mlx_lm/_version.py +++ b/llms/mlx_lm/_version.py @@ -1,3 +1,3 @@ # Copyright © 2023-2024 Apple Inc. -__version__ = "0.21.0" +__version__ = "0.21.5" diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 78a2e802..1fae76fa 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -409,8 +409,7 @@ def speculative_generate_step( for processor in logits_processors: logits = processor(tokens, logits) - logprobs = logits - mx.logsumexp(logits, keepdims=True) - logprobs = logprobs.squeeze(0) + logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) y = sampler(logprobs) return y, logprobs @@ -429,16 +428,14 @@ def speculative_generate_step( prev_tokens = ( mx.concat([prev_tokens, y]) if prev_tokens is not None else y ) - y, logprobs = _process_and_sample( - prev_tokens, logits[:, i : i + 1, :] - ) + y, logprobs = _process_and_sample(prev_tokens, logits[:, i, :]) out_y.append(y) out_logprobs.append(logprobs) return mx.concatenate(out_y, axis=0), mx.concatenate( out_logprobs, axis=0 ) else: - return _process_and_sample(None, logits) + return _process_and_sample(None, logits.squeeze(0)) def _prefill(model, cache, y): while y.size > prefill_step_size: @@ -476,13 +473,9 @@ def speculative_generate_step( num_draft = min(max_tokens - ntoks, num_draft_tokens) draft_tokens = _draft_generate(draft_y, num_draft) if prev_tokens is not None: - prev_tokens = prev_tokens[ - : prev_tokens.size - draft_y.size - num_draft + 1 - ] + prev_tokens = prev_tokens[: prev_tokens.size - y.size - num_draft + 1] y = mx.concatenate([y, draft_tokens]) - tokens, logprobs = _step(model, model_cache, y, num_draft + 1) - mx.eval(tokens, draft_tokens) draft_tokens = draft_tokens.tolist() tokens = tokens.tolist() @@ -514,8 +507,8 @@ def speculative_generate_step( [mx.array(draft_tokens[-1:], mx.uint32), draft_y] ) - if prev_tokens is not None and n < num_draft: - prev_tokens = prev_tokens[: -(num_draft - n)] + if prev_tokens is not None: + prev_tokens = prev_tokens[: -max(num_draft - n, 1)] _rewind_cache(num_draft, n) finally: _rewind_cache(num_draft, n)