mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-15 01:42:31 +08:00
Add top-p sampling for text generation (#486)
This commit is contained in:
@@ -111,6 +111,7 @@ def generate_step(
|
||||
temp: 0.0,
|
||||
repetition_penalty: Optional[float] = None,
|
||||
repetition_context_size: Optional[int] = 20,
|
||||
top_p: float = 1.0,
|
||||
) -> Generator[Tuple[mx.array, mx.array], None, None]:
|
||||
"""
|
||||
A generator producing text based on the given prompt from the model.
|
||||
@@ -133,7 +134,26 @@ def generate_step(
|
||||
if temp == 0:
|
||||
token = mx.argmax(logits, axis=-1)
|
||||
else:
|
||||
token = mx.random.categorical(logits * (1 / temp))
|
||||
if top_p > 0 and top_p < 1.0:
|
||||
if (
|
||||
logits.dtype == mx.bfloat16
|
||||
): # workaround for being unable to load kernel contiguous_scan_inclusive_sum_bfloat16_bfloat16
|
||||
logits = logits.astype(mx.float32)
|
||||
probs = mx.softmax(logits / temp, axis=-1)
|
||||
|
||||
sorted_probs = mx.sort(probs)[::-1]
|
||||
sorted_indices = mx.argsort(probs)[::-1]
|
||||
cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
|
||||
|
||||
top_probs = mx.where(
|
||||
cumulative_probs > 1 - top_p,
|
||||
sorted_probs,
|
||||
mx.zeros_like(sorted_probs),
|
||||
)
|
||||
sorted_token = mx.random.categorical(mx.log(top_probs))
|
||||
token = sorted_indices.squeeze(0)[sorted_token]
|
||||
else:
|
||||
token = mx.random.categorical(logits * (1 / temp))
|
||||
|
||||
prob = softmax_logits[0, token]
|
||||
return token, prob
|
||||
@@ -182,6 +202,7 @@ def generate(
|
||||
formatter: Callable = None,
|
||||
repetition_penalty: Optional[float] = None,
|
||||
repetition_context_size: Optional[int] = None,
|
||||
top_p: float = 1.0,
|
||||
) -> str:
|
||||
"""
|
||||
Generate text from the model.
|
||||
@@ -218,6 +239,7 @@ def generate(
|
||||
temp,
|
||||
repetition_penalty,
|
||||
repetition_context_size,
|
||||
top_p,
|
||||
),
|
||||
range(max_tokens),
|
||||
):
|
||||
|
||||
Reference in New Issue
Block a user