mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-15 01:42:31 +08:00
[MLX LM] Sampler refactor + a few improvements (#1094)
* starting * refactor sampler/processor and a few improvements * fix stream * fix stream generate * fix eos handling in stream generate
This commit is contained in:
@@ -1,10 +1,83 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from functools import partial
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
def make_sampler(
|
||||
temp: float = 0.0,
|
||||
top_p: float = 0.0,
|
||||
min_p: float = 0.0,
|
||||
min_tokens_to_keep: int = 1,
|
||||
) -> Callable[mx.array, mx.array]:
|
||||
"""
|
||||
Make a sampler function for use with ``generate_step``.
|
||||
|
||||
Args:
|
||||
temp (float): The temperature for sampling, if 0 the argmax is used.
|
||||
Default: ``0``.
|
||||
top_p (float, optional): Nulceus sampling, higher means model considers
|
||||
more less likely words.
|
||||
min_p (float, optional): The minimum value (scaled by the top token's
|
||||
probability) that a token probability must have to be considered.
|
||||
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
|
||||
be filtered by min_p sampling.
|
||||
|
||||
Returns:
|
||||
Callable[mx.array, mx.array]:
|
||||
A sampler which takes log-probabilities and returns tokens.
|
||||
"""
|
||||
if temp == 0:
|
||||
return lambda x: mx.argmax(x, axis=-1)
|
||||
elif top_p > 0 and top_p < 1.0:
|
||||
return lambda x: top_p_sampling(x, top_p, temp)
|
||||
elif min_p != 0.0:
|
||||
return lambda x: min_p_sampling(x, min_p, min_tokens_to_keep, temp)
|
||||
else:
|
||||
return lambda x: categorical_sampling(x, temp)
|
||||
|
||||
|
||||
def make_logits_processors(
|
||||
logit_bias: Optional[Dict[int, float]] = None,
|
||||
repetition_penalty: Optional[float] = None,
|
||||
repetition_context_size: Optional[int] = 20,
|
||||
):
|
||||
"""
|
||||
Make logits processors for use with ``generate_step``.
|
||||
|
||||
Args:
|
||||
repetition_penalty (float, optional): The penalty factor for repeating
|
||||
tokens.
|
||||
repetition_context_size (int, optional): The number of tokens to
|
||||
consider for repetition penalty. Default: ``20``.
|
||||
logit_bias (dictionary, optional): Additive logit bias.
|
||||
|
||||
Returns:
|
||||
List[Callable[[mx.array, mx.array], mx.array]]:
|
||||
A list of logits processors. Each processor in the list is a
|
||||
callable which takes an array of tokens and an array of logits
|
||||
and returns the updated logits.
|
||||
"""
|
||||
logits_processors = []
|
||||
if logit_bias:
|
||||
indices = mx.array(list(logit_bias.keys()))
|
||||
values = mx.array(list(logit_bias.values()))
|
||||
|
||||
def logit_bias_processor(_, logits):
|
||||
logits[:, indices] += values
|
||||
return logits
|
||||
|
||||
logits_processors.append(logit_bias_processor)
|
||||
|
||||
if repetition_penalty and repetition_penalty != 0.0:
|
||||
logits_processors.append(
|
||||
make_repetition_penalty(repetition_penalty, repetition_context_size)
|
||||
)
|
||||
return logits_processors
|
||||
|
||||
|
||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||
def min_p_sampling(
|
||||
logits: mx.array,
|
||||
@@ -100,3 +173,36 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
|
||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||
def categorical_sampling(logits, temp):
|
||||
return mx.random.categorical(logits * (1 / temp))
|
||||
|
||||
|
||||
def make_repetition_penalty(penalty: float, context_size: int = 20):
|
||||
"""
|
||||
Make repetition penalty processor.
|
||||
|
||||
Paper: https://arxiv.org/abs/1909.05858
|
||||
|
||||
Args:
|
||||
penalty (float): The repetition penalty factor to be applied.
|
||||
context_size (int): The number of previous tokens to use.
|
||||
Default: ``20``.
|
||||
|
||||
Returns:
|
||||
Callable[[mx.array, List[int]], mx.array]:
|
||||
The repetition penalty processor.
|
||||
"""
|
||||
if penalty < 0 or not isinstance(penalty, float):
|
||||
raise ValueError(f"penalty must be a non-negative float, got {penalty}")
|
||||
|
||||
def repetition_penalty_processor(tokens, logits):
|
||||
if len(tokens) > 0:
|
||||
tokens = tokens[-context_size:]
|
||||
selected_logits = logits[:, tokens]
|
||||
selected_logits = mx.where(
|
||||
selected_logits < 0,
|
||||
selected_logits * penalty,
|
||||
selected_logits / penalty,
|
||||
)
|
||||
logits[:, tokens] = selected_logits
|
||||
return logits
|
||||
|
||||
return repetition_penalty_processor
|
||||
|
||||
Reference in New Issue
Block a user