mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Change DEFAULT_SEED to None for stochastic generation by default (#1323)
* Change DEFAULT_SEED to None for stochastic generation by default * Update llms/mlx_lm/chat.py * Update llms/mlx_lm/generate.py --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
parent
32d10036de
commit
877d2a345b
@ -11,7 +11,7 @@ from .utils import load, stream_generate
|
|||||||
|
|
||||||
DEFAULT_TEMP = 0.0
|
DEFAULT_TEMP = 0.0
|
||||||
DEFAULT_TOP_P = 1.0
|
DEFAULT_TOP_P = 1.0
|
||||||
DEFAULT_SEED = 0
|
DEFAULT_SEED = None
|
||||||
DEFAULT_MAX_TOKENS = 256
|
DEFAULT_MAX_TOKENS = 256
|
||||||
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
|
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
|
||||||
|
|
||||||
@ -36,7 +36,12 @@ def setup_arg_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
|
"--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
|
||||||
)
|
)
|
||||||
parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
|
parser.add_argument(
|
||||||
|
"--seed",
|
||||||
|
type=int,
|
||||||
|
default=DEFAULT_SEED,
|
||||||
|
help="PRNG seed",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-kv-size",
|
"--max-kv-size",
|
||||||
type=int,
|
type=int,
|
||||||
@ -57,6 +62,7 @@ def main():
|
|||||||
parser = setup_arg_parser()
|
parser = setup_arg_parser()
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.seed is not None:
|
||||||
mx.random.seed(args.seed)
|
mx.random.seed(args.seed)
|
||||||
|
|
||||||
model, tokenizer = load(
|
model, tokenizer = load(
|
||||||
|
@ -16,7 +16,7 @@ DEFAULT_TEMP = 0.0
|
|||||||
DEFAULT_TOP_P = 1.0
|
DEFAULT_TOP_P = 1.0
|
||||||
DEFAULT_MIN_P = 0.0
|
DEFAULT_MIN_P = 0.0
|
||||||
DEFAULT_MIN_TOKENS_TO_KEEP = 1
|
DEFAULT_MIN_TOKENS_TO_KEEP = 1
|
||||||
DEFAULT_SEED = 0
|
DEFAULT_SEED = None
|
||||||
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
|
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
|
||||||
DEFAULT_QUANTIZED_KV_START = 5000
|
DEFAULT_QUANTIZED_KV_START = 5000
|
||||||
|
|
||||||
@ -87,7 +87,12 @@ def setup_arg_parser():
|
|||||||
default=DEFAULT_MIN_TOKENS_TO_KEEP,
|
default=DEFAULT_MIN_TOKENS_TO_KEEP,
|
||||||
help="Minimum tokens to keep for min-p sampling.",
|
help="Minimum tokens to keep for min-p sampling.",
|
||||||
)
|
)
|
||||||
parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
|
parser.add_argument(
|
||||||
|
"--seed",
|
||||||
|
type=int,
|
||||||
|
default=DEFAULT_SEED,
|
||||||
|
help="PRNG seed",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--ignore-chat-template",
|
"--ignore-chat-template",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
@ -160,6 +165,8 @@ def setup_arg_parser():
|
|||||||
def main():
|
def main():
|
||||||
parser = setup_arg_parser()
|
parser = setup_arg_parser()
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.seed is not None:
|
||||||
mx.random.seed(args.seed)
|
mx.random.seed(args.seed)
|
||||||
|
|
||||||
# Load the prompt cache and metadata if a cache file is provided
|
# Load the prompt cache and metadata if a cache file is provided
|
||||||
|
Loading…
Reference in New Issue
Block a user