Generation refactor: part 2 (#1099)

* unify with stream_generate

* fixes

* nit

* some cleanup, warnings, tests

* fix test + faster min p + test

* version
This commit is contained in:
Awni Hannun
2024-11-23 11:47:06 -08:00
committed by GitHub
parent 004eb4cc9d
commit 0f135396ae
13 changed files with 184 additions and 197 deletions

View File

@@ -7,6 +7,7 @@ import sys
import mlx.core as mx
from .models.cache import QuantizedKVCache, load_prompt_cache
from .sample_utils import make_sampler
from .utils import generate, load
DEFAULT_PROMPT = "hello"
@@ -97,11 +98,6 @@ def setup_arg_parser():
default=True,
help="Log verbose output when 'True' or 'T' or only print the response when 'False' or 'F'",
)
parser.add_argument(
"--colorize",
action="store_true",
help="Colorize output based on T[0] probability",
)
parser.add_argument(
"--max-kv-size",
type=int,
@@ -137,33 +133,6 @@ def setup_arg_parser():
return parser
def colorprint(color, s):
color_codes = {
"black": 30,
"red": 31,
"green": 32,
"yellow": 33,
"blue": 34,
"magenta": 35,
"cyan": 36,
"white": 39,
}
ccode = color_codes.get(color, 30)
print(f"\033[1m\033[{ccode}m{s}\033[0m", end="", flush=True)
def colorprint_by_t0(s, t0):
if t0 > 0.95:
color = "white"
elif t0 > 0.70:
color = "green"
elif t0 > 0.30:
color = "yellow"
else:
color = "red"
colorprint(color, s)
def main():
parser = setup_arg_parser()
args = parser.parse_args()
@@ -250,21 +219,14 @@ def main():
else:
prompt = args.prompt
if args.colorize and not args.verbose:
raise ValueError("Cannot use --colorize with --verbose=False")
formatter = colorprint_by_t0 if args.colorize else None
sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
response = generate(
model,
tokenizer,
prompt,
args.max_tokens,
max_tokens=args.max_tokens,
verbose=args.verbose,
formatter=formatter,
temp=args.temp,
top_p=args.top_p,
min_p=args.min_p,
min_tokens_to_keep=args.min_tokens_to_keep,
sampler=sampler,
max_kv_size=args.max_kv_size,
prompt_cache=prompt_cache if using_cache else None,
kv_bits=args.kv_bits,