mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
batched min p and fix spec gen sampling (#1222)
This commit is contained in:
@@ -147,11 +147,11 @@ def min_p_sampling(
|
||||
logprobs = logprobs * (1 / temperature)
|
||||
|
||||
# Indices sorted in decreasing order
|
||||
sorted_indices = mx.argsort(-logprobs).squeeze(0)
|
||||
sorted_logprobs = logprobs[..., sorted_indices]
|
||||
sorted_indices = mx.argsort(-logprobs, axis=-1)
|
||||
sorted_logprobs = mx.take_along_axis(logprobs, sorted_indices, axis=-1)
|
||||
|
||||
# Top probability
|
||||
top_logprobs = logprobs[..., sorted_indices[0]]
|
||||
top_logprobs = sorted_logprobs[:, 0:1]
|
||||
|
||||
# Calculate the min_p threshold
|
||||
scaled_min_p = top_logprobs + math.log(min_p)
|
||||
@@ -163,9 +163,9 @@ def min_p_sampling(
|
||||
# Create pool of tokens with probability less than scaled min_p
|
||||
selected_logprobs = mx.where(tokens_to_remove, -float("inf"), sorted_logprobs)
|
||||
|
||||
# Return sampled token
|
||||
sorted_token = mx.random.categorical(selected_logprobs)
|
||||
return sorted_indices[sorted_token]
|
||||
# Return sampled tokens
|
||||
sorted_tokens = mx.random.categorical(selected_logprobs, axis=-1)[:, None]
|
||||
return mx.take_along_axis(sorted_indices, sorted_tokens, axis=-1).squeeze(1)
|
||||
|
||||
|
||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||
@@ -185,7 +185,7 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
|
||||
|
||||
# sort probs in ascending order
|
||||
sorted_indices = mx.argsort(probs, axis=-1)
|
||||
sorted_probs = probs[..., sorted_indices.squeeze(0)]
|
||||
sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=-1)
|
||||
|
||||
cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
|
||||
|
||||
@@ -196,10 +196,8 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
|
||||
0,
|
||||
)
|
||||
|
||||
sorted_token = mx.random.categorical(mx.log(top_probs))
|
||||
token = sorted_indices.squeeze(0)[sorted_token]
|
||||
|
||||
return token
|
||||
sorted_tokens = mx.random.categorical(mx.log(top_probs), axis=-1)[:, None]
|
||||
return mx.take_along_axis(sorted_indices, sorted_tokens, axis=-1).squeeze(1)
|
||||
|
||||
|
||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||
|
@@ -398,8 +398,9 @@ def speculative_generate_step(
|
||||
quantize_cache_fn(cache)
|
||||
|
||||
logprobs = logits - mx.logsumexp(logits, keepdims=True)
|
||||
y = sampler(logprobs).squeeze(0)
|
||||
return y, logprobs.squeeze(0)
|
||||
logprobs = logprobs.squeeze(0)
|
||||
y = sampler(logprobs)
|
||||
return y, logprobs
|
||||
|
||||
def _prefill(model, cache, y):
|
||||
while y.size > prefill_step_size:
|
||||
|
Reference in New Issue
Block a user