batched min p and fix spec gen sampling (#1222)

This commit is contained in:
Awni Hannun
2025-01-27 15:40:31 -08:00
committed by GitHub
parent 77faa14ba4
commit f44a52e2dc
3 changed files with 24 additions and 13 deletions

View File

@@ -398,8 +398,9 @@ def speculative_generate_step(
quantize_cache_fn(cache)
logprobs = logits - mx.logsumexp(logits, keepdims=True)
y = sampler(logprobs).squeeze(0)
return y, logprobs.squeeze(0)
logprobs = logprobs.squeeze(0)
y = sampler(logprobs)
return y, logprobs
def _prefill(model, cache, y):
while y.size > prefill_step_size: