mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 01:41:19 +08:00
batched min p and fix spec gen sampling (#1222)
This commit is contained in:
parent
77faa14ba4
commit
f44a52e2dc
@ -147,11 +147,11 @@ def min_p_sampling(
|
|||||||
logprobs = logprobs * (1 / temperature)
|
logprobs = logprobs * (1 / temperature)
|
||||||
|
|
||||||
# Indices sorted in decreasing order
|
# Indices sorted in decreasing order
|
||||||
sorted_indices = mx.argsort(-logprobs).squeeze(0)
|
sorted_indices = mx.argsort(-logprobs, axis=-1)
|
||||||
sorted_logprobs = logprobs[..., sorted_indices]
|
sorted_logprobs = mx.take_along_axis(logprobs, sorted_indices, axis=-1)
|
||||||
|
|
||||||
# Top probability
|
# Top probability
|
||||||
top_logprobs = logprobs[..., sorted_indices[0]]
|
top_logprobs = sorted_logprobs[:, 0:1]
|
||||||
|
|
||||||
# Calculate the min_p threshold
|
# Calculate the min_p threshold
|
||||||
scaled_min_p = top_logprobs + math.log(min_p)
|
scaled_min_p = top_logprobs + math.log(min_p)
|
||||||
@ -163,9 +163,9 @@ def min_p_sampling(
|
|||||||
# Create pool of tokens with probability less than scaled min_p
|
# Create pool of tokens with probability less than scaled min_p
|
||||||
selected_logprobs = mx.where(tokens_to_remove, -float("inf"), sorted_logprobs)
|
selected_logprobs = mx.where(tokens_to_remove, -float("inf"), sorted_logprobs)
|
||||||
|
|
||||||
# Return sampled token
|
# Return sampled tokens
|
||||||
sorted_token = mx.random.categorical(selected_logprobs)
|
sorted_tokens = mx.random.categorical(selected_logprobs, axis=-1)[:, None]
|
||||||
return sorted_indices[sorted_token]
|
return mx.take_along_axis(sorted_indices, sorted_tokens, axis=-1).squeeze(1)
|
||||||
|
|
||||||
|
|
||||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||||
@ -185,7 +185,7 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
|
|||||||
|
|
||||||
# sort probs in ascending order
|
# sort probs in ascending order
|
||||||
sorted_indices = mx.argsort(probs, axis=-1)
|
sorted_indices = mx.argsort(probs, axis=-1)
|
||||||
sorted_probs = probs[..., sorted_indices.squeeze(0)]
|
sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=-1)
|
||||||
|
|
||||||
cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
|
cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
|
||||||
|
|
||||||
@ -196,10 +196,8 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
|
|||||||
0,
|
0,
|
||||||
)
|
)
|
||||||
|
|
||||||
sorted_token = mx.random.categorical(mx.log(top_probs))
|
sorted_tokens = mx.random.categorical(mx.log(top_probs), axis=-1)[:, None]
|
||||||
token = sorted_indices.squeeze(0)[sorted_token]
|
return mx.take_along_axis(sorted_indices, sorted_tokens, axis=-1).squeeze(1)
|
||||||
|
|
||||||
return token
|
|
||||||
|
|
||||||
|
|
||||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||||
|
@ -398,8 +398,9 @@ def speculative_generate_step(
|
|||||||
quantize_cache_fn(cache)
|
quantize_cache_fn(cache)
|
||||||
|
|
||||||
logprobs = logits - mx.logsumexp(logits, keepdims=True)
|
logprobs = logits - mx.logsumexp(logits, keepdims=True)
|
||||||
y = sampler(logprobs).squeeze(0)
|
logprobs = logprobs.squeeze(0)
|
||||||
return y, logprobs.squeeze(0)
|
y = sampler(logprobs)
|
||||||
|
return y, logprobs
|
||||||
|
|
||||||
def _prefill(model, cache, y):
|
def _prefill(model, cache, y):
|
||||||
while y.size > prefill_step_size:
|
while y.size > prefill_step_size:
|
||||||
|
@ -28,6 +28,12 @@ class TestSampleUtils(unittest.TestCase):
|
|||||||
token = top_p_sampling(logits, 0.95, temperature).item()
|
token = top_p_sampling(logits, 0.95, temperature).item()
|
||||||
self.assertTrue(token in (1, 2, 3))
|
self.assertTrue(token in (1, 2, 3))
|
||||||
|
|
||||||
|
# Batch mode works
|
||||||
|
probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]])
|
||||||
|
logits = mx.log(probs)
|
||||||
|
tokens = top_p_sampling(logits, 0.5, temperature)
|
||||||
|
self.assertEqual(tokens.tolist(), [0, 1])
|
||||||
|
|
||||||
def test_min_p_sampling(self):
|
def test_min_p_sampling(self):
|
||||||
probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
|
probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
|
||||||
logits = mx.log(probs)
|
logits = mx.log(probs)
|
||||||
@ -42,6 +48,12 @@ class TestSampleUtils(unittest.TestCase):
|
|||||||
token = min_p_sampling(logits, 0.05)
|
token = min_p_sampling(logits, 0.05)
|
||||||
self.assertTrue(token in (0, 3))
|
self.assertTrue(token in (0, 3))
|
||||||
|
|
||||||
|
# Batch mode works
|
||||||
|
probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]])
|
||||||
|
logits = mx.log(probs)
|
||||||
|
tokens = min_p_sampling(logits, 0.7)
|
||||||
|
self.assertEqual(tokens.tolist(), [0, 1])
|
||||||
|
|
||||||
def test_top_k_sampling(self):
|
def test_top_k_sampling(self):
|
||||||
probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
|
probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
|
||||||
logits = mx.log(probs)
|
logits = mx.log(probs)
|
||||||
|
Loading…
Reference in New Issue
Block a user