[MLX LM] Sampler refactor + a few improvements (#1094)

* starting

* refactor sampler/processor and a few improvements

* fix stream

* fix stream generate

* fix eos handling in stream generate
Author:  Awni Hannun
Date:    2024-11-07 16:15:24 -08:00
Committed by: GitHub
Parent:  ed9e81dd58
Commit:  657b4cc0aa
10 changed files with 259 additions and 239 deletions
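
The central idea of the refactor is that sampling settings become a single sampler callable built once from the CLI knobs, rather than loose temp/top-p keyword arguments threaded through each call site. A minimal sketch of that pattern, assuming the post-refactor mlx_lm.sample_utils module exposes make_sampler with these keyword names (treat the exact signature as an assumption):

    # Sketch only: make_sampler and its keyword names are assumed from the
    # post-refactor mlx_lm.sample_utils; verify against your installed version.
    from mlx_lm.sample_utils import make_sampler

    # One callable captures every sampling knob; it maps logits -> token id(s).
    sampler = make_sampler(temp=0.7, top_p=0.9, min_p=0.05, min_tokens_to_keep=1)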


@@ -13,6 +13,8 @@ DEFAULT_PROMPT = "hello"
 DEFAULT_MAX_TOKENS = 100
 DEFAULT_TEMP = 0.0
 DEFAULT_TOP_P = 1.0
+DEFAULT_MIN_P = 0.0
+DEFAULT_MIN_TOKENS_TO_KEEP = 1
 DEFAULT_SEED = 0
 DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
 DEFAULT_QUANTIZED_KV_START = 5000
@@ -52,6 +54,7 @@ def setup_arg_parser():
     )
     parser.add_argument(
         "--prompt",
+        "-p",
         default=DEFAULT_PROMPT,
         help="Message to be processed by the model ('-' reads from stdin)",
     )
@@ -68,6 +71,15 @@ def setup_arg_parser():
     parser.add_argument(
         "--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
     )
+    parser.add_argument(
+        "--min-p", type=float, default=DEFAULT_MIN_P, help="Sampling min-p"
+    )
+    parser.add_argument(
+        "--min-tokens-to-keep",
+        type=int,
+        default=DEFAULT_MIN_TOKENS_TO_KEEP,
+        help="Minimum tokens to keep for min-p sampling.",
+    )
     parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
     parser.add_argument(
         "--ignore-chat-template",
@@ -247,6 +259,8 @@ def main():
         formatter=formatter,
         temp=args.temp,
         top_p=args.top_p,
+        min_p=args.min_p,
+        min_tokens_to_keep=args.min_tokens_to_keep,
         max_kv_size=args.max_kv_size,
         prompt_cache=prompt_cache if using_cache else None,
         kv_bits=args.kv_bits,
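
With the wiring above, the new flags flow from the command line through main() into the generation call. A usage example (the module entry point is assumed from mlx_lm of this era):

    python -m mlx_lm.generate --model mlx-community/Llama-3.2-3B-Instruct-4bit \
        -p "hello" --temp 0.8 --min-p 0.05 --min-tokens-to-keep 1

Note that min-p only changes the output when --temp is raised above the default of 0.0, since greedy decoding always picks the top token regardless of how the rest of the pool is filtered.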