Add top-p sampling for text generation (#486)

This commit is contained in:
peterjc123
2024-02-26 22:18:11 +08:00
committed by GitHub
parent 47dd6bd17f
commit ccb278bcbd
2 changed files with 28 additions and 2 deletions

View File

@@ -8,6 +8,7 @@ DEFAULT_MODEL_PATH = "mlx_model"
DEFAULT_PROMPT = "hello"
DEFAULT_MAX_TOKENS = 100
DEFAULT_TEMP = 0.6
DEFAULT_TOP_P = 1.0
DEFAULT_SEED = 0
@@ -44,6 +45,9 @@ def setup_arg_parser():
parser.add_argument(
"--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
)
parser.add_argument(
"--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
)
parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
parser.add_argument(
"--ignore-chat-template",
@@ -109,7 +113,7 @@ def main(args):
formatter = colorprint_by_t0 if args.colorize else None
generate(
model, tokenizer, prompt, args.temp, args.max_tokens, True, formatter=formatter
model, tokenizer, prompt, args.temp, args.max_tokens, True, formatter=formatter, top_p=args.top_p
)