mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00
make_sampler creates sampler chain with all sampling parameters (#1330)
* top_p refactor
* top_k and min_p refactor
* Create sampler chain
* Remove unnecessary mx.where
* Use mx.allclose
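The heart of the change: each sampling strategy becomes a filter that maps logits to logits, and make_sampler composes whichever filters are active before a single categorical draw. A minimal standalone sketch of that pattern, assuming only mlx.core (keep_top_k and make_chain are illustrative names, not this repo's API):

import mlx.core as mx

# Hypothetical stand-in for apply_top_k: keep the k largest logits,
# push everything else to -inf.
def keep_top_k(logits: mx.array, k: int) -> mx.array:
    drop = mx.argpartition(-logits, kth=k - 1, axis=-1)[..., k:]
    return mx.put_along_axis(
        logits, drop, mx.array(-float("inf"), logits.dtype), axis=-1
    )

def make_chain(temp: float, top_k: int = 0):
    filters = []
    if top_k > 0:
        filters.append(lambda x: keep_top_k(x, top_k))

    def sampler(logits: mx.array) -> mx.array:
        # Each stage maps logits -> logits, so stages compose freely.
        for f in filters:
            logits = f(logits)
        # Temperature is applied once, at the final categorical draw.
        return mx.random.categorical(logits * (1 / temp), axis=-1)

    return sampler

sample = make_chain(temp=0.8, top_k=3)
print(sample(mx.array([[0.1, 2.0, -1.0, 3.5, 0.7]])))  # token id drawn from the top 3

Because every stage shares the logits -> logits shape, any combination of top-k, top-p, and min-p can be enabled together, replacing the mutually exclusive if/elif dispatch removed in the first hunk below.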
@@ -35,14 +35,25 @@ def make_sampler(
     """
     if temp == 0:
         return lambda x: mx.argmax(x, axis=-1)
-    elif top_p > 0 and top_p < 1.0:
-        return lambda x: top_p_sampling(x, top_p, temp)
-    elif min_p != 0.0:
-        return lambda x: min_p_sampling(x, min_p, min_tokens_to_keep, temp)
-    elif top_k > 0:
-        return lambda x: top_k_sampling(x, top_k, temp)
-    else:
-        return lambda x: categorical_sampling(x, temp)
+
+    # Create sampler chain
+    sampling_methods = []
+    if top_k > 0:
+        sampling_methods.append(lambda x: apply_top_k(x, top_k))
+    if top_p > 0 and top_p < 1.0:
+        sampling_methods.append(lambda x: apply_top_p(x, top_p))
+    if min_p != 0.0:
+        sampling_methods.append(lambda x: apply_min_p(x, min_p, min_tokens_to_keep))
+
+    # Apply the sampling methods
+    def sampler(logits):
+        for method in sampling_methods:
+            logits = method(logits)
+
+        # Return the sampled token
+        return categorical_sampling(logits, temp)
+
+    return sampler
 
 
 def make_logits_processors(
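A hypothetical usage sketch of the refactored make_sampler (the import path assumes this repo's mlx_lm package; argument names follow the hunk above, values are illustrative):

import mlx.core as mx
from mlx_lm.sample_utils import make_sampler

# With the chain, top-k and top-p combine instead of being either/or.
sampler = make_sampler(temp=0.7, top_p=0.9, top_k=40)
logits = mx.random.normal((1, 32000))  # fake batch-of-1 vocabulary logits
token = sampler(logits)
print(token.item())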
@@ -85,10 +96,9 @@ def make_logits_processors(
 
 
 @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
-def top_k_sampling(
+def apply_top_k(
     logprobs: mx.array,
     top_k: int,
-    temperature=1.0,
 ) -> mx.array:
     """
     Sample from only the top K tokens ranked by probability.
@@ -103,20 +113,18 @@ def top_k_sampling(
             f"`top_k` has to be an integer in the (0, {vocab_size}] interval,"
             f" but is {top_k}."
         )
-    logprobs = logprobs * (1 / temperature)
     mask_idx = mx.argpartition(-logprobs, kth=top_k - 1, axis=-1)[..., top_k:]
     masked_logprobs = mx.put_along_axis(
         logprobs, mask_idx, mx.array(-float("inf"), logprobs.dtype), axis=-1
     )
-    return mx.random.categorical(masked_logprobs, axis=-1)
+    return masked_logprobs
 
 
 @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
-def min_p_sampling(
+def apply_min_p(
     logprobs: mx.array,
     min_p: float,
     min_tokens_to_keep: int = 1,
-    temperature=1.0,
 ) -> mx.array:
     """
     Apply min-p sampling to the logprobs.
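To illustrate the top-k masking retained in the hunk above (a standalone toy run, not the library code): argpartition locates every index outside the top k, and put_along_axis overwrites those logits with -inf, so softmax gives them zero mass.

import mlx.core as mx

logits = mx.array([[1.0, 4.0, 2.0, 3.0]])
k = 2
# Indices of everything outside the top 2 logits...
drop = mx.argpartition(-logits, kth=k - 1, axis=-1)[..., k:]
# ...are overwritten with -inf, mirroring the hunk above.
masked = mx.put_along_axis(
    logits, drop, mx.array(-float("inf"), logits.dtype), axis=-1
)
print(masked)                       # [[-inf, 4, -inf, 3]]
print(mx.softmax(masked, axis=-1))  # zero probability on masked entries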
@@ -144,8 +152,6 @@ def min_p_sampling(
         )
     # reference implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L531-L605
 
-    logprobs = logprobs * (1 / temperature)
-
     # Indices sorted in decreasing order
     sorted_indices = mx.argsort(-logprobs, axis=-1)
     sorted_logprobs = mx.take_along_axis(logprobs, sorted_indices, axis=-1)
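Dropping `logprobs = logprobs * (1 / temperature)` here matches the new design, where temperature is applied once, at the final draw inside make_sampler. A small standalone sketch of what that scaling does to the distribution (illustrative values):

import mlx.core as mx

logits = mx.array([[2.0, 1.0, 0.0]])
for temp in (0.5, 1.0, 2.0):
    # 1/temp scaling sharpens (temp < 1) or flattens (temp > 1) the
    # distribution that the final categorical draw samples from.
    print(temp, mx.softmax(logits * (1 / temp), axis=-1))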
@@ -163,25 +169,31 @@ def min_p_sampling(
     # Create pool of tokens with probability less than scaled min_p
     selected_logprobs = mx.where(tokens_to_remove, -float("inf"), sorted_logprobs)
 
-    # Return sampled tokens
-    sorted_tokens = mx.random.categorical(selected_logprobs, axis=-1)[:, None]
-    return mx.take_along_axis(sorted_indices, sorted_tokens, axis=-1).squeeze(1)
+    # Create a mapping to rearrange back to original indices
+    # Use argsort of sorted_indices to get the inverse permutation
+    inverse_indices = mx.argsort(sorted_indices, axis=-1)
+
+    # Rearrange selected_logprobs back to original order
+    original_order_logprobs = mx.take_along_axis(
+        selected_logprobs, inverse_indices, axis=-1
+    )
+
+    return original_order_logprobs
 
 
 @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
-def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array:
+def apply_top_p(logits: mx.array, top_p: float) -> mx.array:
     """
     Apply top-p (nucleus) sampling to logits.
 
     Args:
         logits: The logits from the model's output.
         top_p: The cumulative probability threshold for top-p filtering.
-        temperature: Temperature parameter for softmax distribution reshaping.
     Returns:
         token selected based on the top-p criterion.
     """
     # referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460
-    probs = mx.softmax(logits * (1 / temperature), axis=-1)
+    probs = mx.softmax(logits, axis=-1)
 
     # sort probs in ascending order
     sorted_indices = mx.argsort(probs, axis=-1)
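The new return path relies on a standard trick: the argsort of a permutation is its inverse, so the filtered values can be put back in original token order. A standalone check (using mx.allclose, per the commit message; x is illustrative):

import mlx.core as mx

x = mx.array([[0.3, 0.9, 0.1, 0.5]])
order = mx.argsort(-x, axis=-1)                   # descending sort order
x_sorted = mx.take_along_axis(x, order, axis=-1)  # [[0.9, 0.5, 0.3, 0.1]]
# argsort of the permutation is its inverse: it sends each sorted
# position back to the token's original index.
inverse = mx.argsort(order, axis=-1)
x_back = mx.take_along_axis(x_sorted, inverse, axis=-1)
print(mx.allclose(x_back, x))                     # True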
@@ -196,8 +208,15 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array:
         0,
     )
 
-    sorted_tokens = mx.random.categorical(mx.log(top_probs), axis=-1)[:, None]
-    return mx.take_along_axis(sorted_indices, sorted_tokens, axis=-1).squeeze(1)
+    # Create a mapping to rearrange back to original indices
+    # Use argsort of sorted_indices to get the inverse permutation
+    inverse_indices = mx.argsort(sorted_indices, axis=-1)
+
+    # Rearrange top_probs back to original order
+    original_order_probs = mx.take_along_axis(top_probs, inverse_indices, axis=-1)
+
+    # Convert back to logits and return
+    return mx.log(original_order_probs)
 
 
 @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
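One property the chain depends on (a sketch, not library code): probabilities zeroed by the top-p cutoff become -inf logits under mx.log, so the final categorical draw can never select them.

import mlx.core as mx

probs = mx.array([[0.0, 0.7, 0.0, 0.3]])  # 0.0 where top-p dropped tokens
logits = mx.log(probs)                    # -inf at the dropped positions
print(logits)
# Repeated draws only ever return a kept index (1 or 3 here).
print([mx.random.categorical(logits).item() for _ in range(8)])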