mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-09 10:26:38 +08:00
Add top-p sampling for text generation (#486)
This commit is contained in:
parent
47dd6bd17f
commit
ccb278bcbd
@@ -8,6 +8,7 @@ DEFAULT_MODEL_PATH = "mlx_model"
|
|||||||
DEFAULT_PROMPT = "hello"
|
DEFAULT_PROMPT = "hello"
|
||||||
DEFAULT_MAX_TOKENS = 100
|
DEFAULT_MAX_TOKENS = 100
|
||||||
DEFAULT_TEMP = 0.6
|
DEFAULT_TEMP = 0.6
|
||||||
|
DEFAULT_TOP_P = 1.0
|
||||||
DEFAULT_SEED = 0
|
DEFAULT_SEED = 0
|
||||||
|
|
||||||
|
|
||||||
@@ -44,6 +45,9 @@ def setup_arg_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
|
"--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
|
||||||
|
)
|
||||||
parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
|
parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--ignore-chat-template",
|
"--ignore-chat-template",
|
||||||
@@ -109,7 +113,7 @@ def main(args):
|
|||||||
formatter = colorprint_by_t0 if args.colorize else None
|
formatter = colorprint_by_t0 if args.colorize else None
|
||||||
|
|
||||||
generate(
|
generate(
|
||||||
model, tokenizer, prompt, args.temp, args.max_tokens, True, formatter=formatter
|
model, tokenizer, prompt, args.temp, args.max_tokens, True, formatter=formatter, top_p=args.top_p
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@@ -111,6 +111,7 @@ def generate_step(
|
|||||||
temp: 0.0,
|
temp: 0.0,
|
||||||
repetition_penalty: Optional[float] = None,
|
repetition_penalty: Optional[float] = None,
|
||||||
repetition_context_size: Optional[int] = 20,
|
repetition_context_size: Optional[int] = 20,
|
||||||
|
top_p: float = 1.0,
|
||||||
) -> Generator[Tuple[mx.array, mx.array], None, None]:
|
) -> Generator[Tuple[mx.array, mx.array], None, None]:
|
||||||
"""
|
"""
|
||||||
A generator producing text based on the given prompt from the model.
|
A generator producing text based on the given prompt from the model.
|
||||||
@@ -133,7 +134,26 @@ def generate_step(
|
|||||||
if temp == 0:
|
if temp == 0:
|
||||||
token = mx.argmax(logits, axis=-1)
|
token = mx.argmax(logits, axis=-1)
|
||||||
else:
|
else:
|
||||||
token = mx.random.categorical(logits * (1 / temp))
|
if top_p > 0 and top_p < 1.0:
|
||||||
|
if (
|
||||||
|
logits.dtype == mx.bfloat16
|
||||||
|
): # workaround for being unable to load kernel contiguous_scan_inclusive_sum_bfloat16_bfloat16
|
||||||
|
logits = logits.astype(mx.float32)
|
||||||
|
probs = mx.softmax(logits / temp, axis=-1)
|
||||||
|
|
||||||
|
sorted_probs = mx.sort(probs)[::-1]
|
||||||
|
sorted_indices = mx.argsort(probs)[::-1]
|
||||||
|
cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
|
||||||
|
|
||||||
|
top_probs = mx.where(
|
||||||
|
cumulative_probs > 1 - top_p,
|
||||||
|
sorted_probs,
|
||||||
|
mx.zeros_like(sorted_probs),
|
||||||
|
)
|
||||||
|
sorted_token = mx.random.categorical(mx.log(top_probs))
|
||||||
|
token = sorted_indices.squeeze(0)[sorted_token]
|
||||||
|
else:
|
||||||
|
token = mx.random.categorical(logits * (1 / temp))
|
||||||
|
|
||||||
prob = softmax_logits[0, token]
|
prob = softmax_logits[0, token]
|
||||||
return token, prob
|
return token, prob
|
||||||
@@ -182,6 +202,7 @@ def generate(
|
|||||||
formatter: Callable = None,
|
formatter: Callable = None,
|
||||||
repetition_penalty: Optional[float] = None,
|
repetition_penalty: Optional[float] = None,
|
||||||
repetition_context_size: Optional[int] = None,
|
repetition_context_size: Optional[int] = None,
|
||||||
|
top_p: float = 1.0,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Generate text from the model.
|
Generate text from the model.
|
||||||
@@ -218,6 +239,7 @@ def generate(
|
|||||||
temp,
|
temp,
|
||||||
repetition_penalty,
|
repetition_penalty,
|
||||||
repetition_context_size,
|
repetition_context_size,
|
||||||
|
top_p,
|
||||||
),
|
),
|
||||||
range(max_tokens),
|
range(max_tokens),
|
||||||
):
|
):
|
||||||
|
Loading…
Reference in New Issue
Block a user