mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-15 01:42:31 +08:00
batched min p and fix spec gen sampling (#1222)
This commit is contained in:
@@ -398,8 +398,9 @@ def speculative_generate_step(
|
||||
quantize_cache_fn(cache)
|
||||
|
||||
 logprobs = logits - mx.logsumexp(logits, keepdims=True)
-y = sampler(logprobs).squeeze(0)
-return y, logprobs.squeeze(0)
+logprobs = logprobs.squeeze(0)
+y = sampler(logprobs)
+return y, logprobs
|
||||
|
||||
def _prefill(model, cache, y):
|
||||
while y.size > prefill_step_size:
|
||||
|
||||
Reference in New Issue
Block a user