mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
Generation refactor: part 2 (#1099)
* unify with stream_generate * fixes * nit * some cleanup, warnings, tests * fix test + faster min p + test * version
This commit is contained in:
@@ -7,6 +7,7 @@ import sys
|
||||
import mlx.core as mx
|
||||
|
||||
from .models.cache import QuantizedKVCache, load_prompt_cache
|
||||
from .sample_utils import make_sampler
|
||||
from .utils import generate, load
|
||||
|
||||
DEFAULT_PROMPT = "hello"
|
||||
@@ -97,11 +98,6 @@ def setup_arg_parser():
|
||||
default=True,
|
||||
help="Log verbose output when 'True' or 'T' or only print the response when 'False' or 'F'",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--colorize",
|
||||
action="store_true",
|
||||
help="Colorize output based on T[0] probability",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-kv-size",
|
||||
type=int,
|
||||
@@ -137,33 +133,6 @@ def setup_arg_parser():
|
||||
return parser
|
||||
|
||||
|
||||
def colorprint(color, s):
|
||||
color_codes = {
|
||||
"black": 30,
|
||||
"red": 31,
|
||||
"green": 32,
|
||||
"yellow": 33,
|
||||
"blue": 34,
|
||||
"magenta": 35,
|
||||
"cyan": 36,
|
||||
"white": 39,
|
||||
}
|
||||
ccode = color_codes.get(color, 30)
|
||||
print(f"\033[1m\033[{ccode}m{s}\033[0m", end="", flush=True)
|
||||
|
||||
|
||||
def colorprint_by_t0(s, t0):
|
||||
if t0 > 0.95:
|
||||
color = "white"
|
||||
elif t0 > 0.70:
|
||||
color = "green"
|
||||
elif t0 > 0.30:
|
||||
color = "yellow"
|
||||
else:
|
||||
color = "red"
|
||||
colorprint(color, s)
|
||||
|
||||
|
||||
def main():
|
||||
parser = setup_arg_parser()
|
||||
args = parser.parse_args()
|
||||
@@ -250,21 +219,14 @@ def main():
|
||||
else:
|
||||
prompt = args.prompt
|
||||
|
||||
if args.colorize and not args.verbose:
|
||||
raise ValueError("Cannot use --colorize with --verbose=False")
|
||||
formatter = colorprint_by_t0 if args.colorize else None
|
||||
|
||||
sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
|
||||
response = generate(
|
||||
model,
|
||||
tokenizer,
|
||||
prompt,
|
||||
args.max_tokens,
|
||||
max_tokens=args.max_tokens,
|
||||
verbose=args.verbose,
|
||||
formatter=formatter,
|
||||
temp=args.temp,
|
||||
top_p=args.top_p,
|
||||
min_p=args.min_p,
|
||||
min_tokens_to_keep=args.min_tokens_to_keep,
|
||||
sampler=sampler,
|
||||
max_kv_size=args.max_kv_size,
|
||||
prompt_cache=prompt_cache if using_cache else None,
|
||||
kv_bits=args.kv_bits,
|
||||
|
Reference in New Issue
Block a user