From 877d2a345b8119ad9ed50e2c273a5064ddd3b48c Mon Sep 17 00:00:00 2001 From: cavit99 <35897738+cavit99@users.noreply.github.com> Date: Thu, 6 Mar 2025 14:49:35 +0000 Subject: [PATCH] Change DEFAULT_SEED to None for stochastic generation by default (#1323) * Change DEFAULT_SEED to None for stochastic generation by default * Update llms/mlx_lm/chat.py * Update llms/mlx_lm/generate.py --------- Co-authored-by: Awni Hannun --- llms/mlx_lm/chat.py | 12 +++++++++--- llms/mlx_lm/generate.py | 13 ++++++++++--- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/llms/mlx_lm/chat.py b/llms/mlx_lm/chat.py index 5c0b78db..d8e1ccb9 100644 --- a/llms/mlx_lm/chat.py +++ b/llms/mlx_lm/chat.py @@ -11,7 +11,7 @@ from .utils import load, stream_generate DEFAULT_TEMP = 0.0 DEFAULT_TOP_P = 1.0 -DEFAULT_SEED = 0 +DEFAULT_SEED = None DEFAULT_MAX_TOKENS = 256 DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit" @@ -36,7 +36,12 @@ def setup_arg_parser(): parser.add_argument( "--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p" ) - parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed") + parser.add_argument( + "--seed", + type=int, + default=DEFAULT_SEED, + help="PRNG seed", + ) parser.add_argument( "--max-kv-size", type=int, @@ -57,7 +62,8 @@ def main(): parser = setup_arg_parser() args = parser.parse_args() - mx.random.seed(args.seed) + if args.seed is not None: + mx.random.seed(args.seed) model, tokenizer = load( args.model, diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index bd11dcf0..7d58da82 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -16,7 +16,7 @@ DEFAULT_TEMP = 0.0 DEFAULT_TOP_P = 1.0 DEFAULT_MIN_P = 0.0 DEFAULT_MIN_TOKENS_TO_KEEP = 1 -DEFAULT_SEED = 0 +DEFAULT_SEED = None DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit" DEFAULT_QUANTIZED_KV_START = 5000 @@ -87,7 +87,12 @@ def setup_arg_parser(): default=DEFAULT_MIN_TOKENS_TO_KEEP, help="Minimum tokens to keep for min-p sampling.", ) - parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed") + parser.add_argument( + "--seed", + type=int, + default=DEFAULT_SEED, + help="PRNG seed", + ) parser.add_argument( "--ignore-chat-template", action="store_true", @@ -160,7 +165,9 @@ def setup_arg_parser(): def main(): parser = setup_arg_parser() args = parser.parse_args() - mx.random.seed(args.seed) + + if args.seed is not None: + mx.random.seed(args.seed) # Load the prompt cache and metadata if a cache file is provided using_cache = args.prompt_cache_file is not None