mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
Add top-p sampling for text generation (#486)
This commit is contained in:
@@ -8,6 +8,7 @@ DEFAULT_MODEL_PATH = "mlx_model"
|
||||
DEFAULT_PROMPT = "hello"
|
||||
DEFAULT_MAX_TOKENS = 100
|
||||
DEFAULT_TEMP = 0.6
|
||||
DEFAULT_TOP_P = 1.0
|
||||
DEFAULT_SEED = 0
|
||||
|
||||
|
||||
@@ -44,6 +45,9 @@ def setup_arg_parser():
|
||||
parser.add_argument(
|
||||
"--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
|
||||
parser.add_argument(
|
||||
"--ignore-chat-template",
|
||||
@@ -109,7 +113,7 @@ def main(args):
|
||||
formatter = colorprint_by_t0 if args.colorize else None
|
||||
|
||||
generate(
|
||||
model, tokenizer, prompt, args.temp, args.max_tokens, True, formatter=formatter
|
||||
model, tokenizer, prompt, args.temp, args.max_tokens, True, formatter=formatter, top_p=args.top_p
|
||||
)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user