batched min p and fix spec gen sampling (#1222)

Awni Hannun 2025-01-27 15:40:31 -08:00 committed by GitHub
parent 77faa14ba4
commit f44a52e2dc
3 changed files with 24 additions and 13 deletions


@@ -147,11 +147,11 @@ def min_p_sampling(
     logprobs = logprobs * (1 / temperature)
 
     # Indices sorted in decreasing order
-    sorted_indices = mx.argsort(-logprobs).squeeze(0)
-    sorted_logprobs = logprobs[..., sorted_indices]
+    sorted_indices = mx.argsort(-logprobs, axis=-1)
+    sorted_logprobs = mx.take_along_axis(logprobs, sorted_indices, axis=-1)
 
     # Top probability
-    top_logprobs = logprobs[..., sorted_indices[0]]
+    top_logprobs = sorted_logprobs[:, 0:1]
 
     # Calculate the min_p threshold
     scaled_min_p = top_logprobs + math.log(min_p)
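Aside (commentary, not part of the diff): the switch from fancy indexing to mx.take_along_axis is what makes the sort batched. With a (batch, vocab) array, logprobs[..., sorted_indices] broadcasts every row's index set against every row, producing a (batch, batch, vocab) result, while take_along_axis pairs row i of the values with row i of the indices. A minimal sketch with made-up numbers:

import mlx.core as mx

logprobs = mx.log(mx.array([[0.1, 0.6, 0.2, 0.1],
                            [0.7, 0.1, 0.1, 0.1]]))
# Per-row decreasing order, shape (2, 4)
sorted_indices = mx.argsort(-logprobs, axis=-1)
# Fancy indexing would give shape (2, 2, 4); take_along_axis keeps (2, 4)
sorted_logprobs = mx.take_along_axis(logprobs, sorted_indices, axis=-1)
print(sorted_logprobs.shape)  # (2, 4)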
@@ -163,9 +163,9 @@ def min_p_sampling(
     # Create pool of tokens with probability less than scaled min_p
     selected_logprobs = mx.where(tokens_to_remove, -float("inf"), sorted_logprobs)
 
-    # Return sampled token
-    sorted_token = mx.random.categorical(selected_logprobs)
-    return sorted_indices[sorted_token]
+    # Return sampled tokens
+    sorted_tokens = mx.random.categorical(selected_logprobs, axis=-1)[:, None]
+    return mx.take_along_axis(sorted_indices, sorted_tokens, axis=-1).squeeze(1)
 
 
 @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
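Again as commentary: categorical with axis=-1 draws one position per batch row within the sorted order, and the final take_along_axis maps those positions back to vocabulary ids. A small sketch with made-up values:

import mlx.core as mx

# (batch, vocab) pool after masking; -inf entries are excluded from sampling
selected_logprobs = mx.array([[0.0, -1.0, -float("inf")],
                              [-0.5, 0.0, -float("inf")]])
# Per-row original vocabulary ids in sorted order
sorted_indices = mx.array([[2, 0, 1],
                           [1, 2, 0]])
# One draw per row; [:, None] restores a trailing axis for the gather
sorted_tokens = mx.random.categorical(selected_logprobs, axis=-1)[:, None]
tokens = mx.take_along_axis(sorted_indices, sorted_tokens, axis=-1).squeeze(1)
print(tokens.shape)  # (2,)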
@@ -185,7 +185,7 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array:
     # sort probs in ascending order
     sorted_indices = mx.argsort(probs, axis=-1)
-    sorted_probs = probs[..., sorted_indices.squeeze(0)]
+    sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=-1)
 
     cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
@ -196,10 +196,8 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
0,
)
sorted_token = mx.random.categorical(mx.log(top_probs))
token = sorted_indices.squeeze(0)[sorted_token]
return token
sorted_tokens = mx.random.categorical(mx.log(top_probs), axis=-1)[:, None]
return mx.take_along_axis(sorted_indices, sorted_tokens, axis=-1).squeeze(1)
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
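With both gathers batched, top_p_sampling accepts a whole (batch, vocab) matrix of logits at once. A usage sketch, assuming the function is importable from mlx_lm.sample_utils (the file path is not shown in this view):

import mlx.core as mx
from mlx_lm.sample_utils import top_p_sampling

logits = mx.log(mx.array([[0.9, 0.0, 0.0, 0.1],
                          [0.0, 0.8, 0.0, 0.1]]))
# top_p=0.5, temperature=1.0; each row's nucleus holds a single token here
tokens = top_p_sampling(logits, 0.5, 1.0)
print(tokens.tolist())  # [0, 1]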


@@ -398,8 +398,9 @@ def speculative_generate_step(
             quantize_cache_fn(cache)
 
             logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
-            y = sampler(logprobs).squeeze(0)
-            return y, logprobs.squeeze(0)
+            logprobs = logprobs.squeeze(0)
+            y = sampler(logprobs)
+            return y, logprobs
 
     def _prefill(model, cache, y):
         while y.size > prefill_step_size:
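The speculative-generation fix is an ordering change: the leading axis must be squeezed away before sampling, so the draft positions become the batch rows that the batched samplers above expect. A toy shape check, substituting argmax for the real sampler:

import mlx.core as mx

logits = mx.random.normal((1, 3, 5))  # (1, n_draft, vocab)
logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
logprobs = logprobs.squeeze(0)        # (3, 5): draft positions as batch rows
y = mx.argmax(logprobs, axis=-1)      # (3,): one token per draft position
print(y.shape, logprobs.shape)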


@@ -28,6 +28,12 @@ class TestSampleUtils(unittest.TestCase):
         token = top_p_sampling(logits, 0.95, temperature).item()
         self.assertTrue(token in (1, 2, 3))
 
+        # Batch mode works
+        probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]])
+        logits = mx.log(probs)
+        tokens = top_p_sampling(logits, 0.5, temperature)
+        self.assertEqual(tokens.tolist(), [0, 1])
+
     def test_min_p_sampling(self):
         probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
         logits = mx.log(probs)
@@ -42,6 +48,12 @@ class TestSampleUtils(unittest.TestCase):
         token = min_p_sampling(logits, 0.05)
         self.assertTrue(token in (0, 3))
 
+        # Batch mode works
+        probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]])
+        logits = mx.log(probs)
+        tokens = min_p_sampling(logits, 0.7)
+        self.assertEqual(tokens.tolist(), [0, 1])
+
     def test_top_k_sampling(self):
         probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
         logits = mx.log(probs)
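Both new tests are deterministic: with these probabilities, each row's top-p nucleus and min-p pool collapse to a single token, so the categorical draw has only one choice per row. A standalone smoke check mirroring them, again assuming the mlx_lm.sample_utils import path:

import mlx.core as mx
from mlx_lm.sample_utils import min_p_sampling, top_p_sampling

probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]])
logits = mx.log(probs)
assert top_p_sampling(logits, 0.5, 1.0).tolist() == [0, 1]
assert min_p_sampling(logits, 0.7).tolist() == [0, 1]
print("batched sampling OK")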