mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 19:06:37 +08:00
unify with stream_generate
This commit is contained in:
parent
004eb4cc9d
commit
431988721f
@ -97,11 +97,6 @@ def setup_arg_parser():
|
|||||||
default=True,
|
default=True,
|
||||||
help="Log verbose output when 'True' or 'T' or only print the response when 'False' or 'F'",
|
help="Log verbose output when 'True' or 'T' or only print the response when 'False' or 'F'",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--colorize",
|
|
||||||
action="store_true",
|
|
||||||
help="Colorize output based on T[0] probability",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-kv-size",
|
"--max-kv-size",
|
||||||
type=int,
|
type=int,
|
||||||
@ -137,33 +132,6 @@ def setup_arg_parser():
|
|||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def colorprint(color, s):
    """
    Print ``s`` in bold using the ANSI foreground color named ``color``.

    Nothing is appended after the text (no newline) and stdout is flushed
    right away, so consecutive calls build up a single line piece by piece.

    Args:
        color (str): One of ``"black"``, ``"red"``, ``"green"``,
          ``"yellow"``, ``"blue"``, ``"magenta"``, ``"cyan"`` or
          ``"white"``. Any other value falls back to the code for black.
        s (str): The text to print.
    """
    color_codes = {
        "black": 30,
        "red": 31,
        "green": 32,
        "yellow": 33,
        "blue": 34,
        "magenta": 35,
        "cyan": 36,
        # 39 is the SGR "default foreground" code, not 37 (literal white).
        # NOTE(review): presumably chosen so text stays readable on light
        # terminal themes - confirm before "correcting" it to 37.
        "white": 39,
    }
    ccode = color_codes.get(color, 30)
    # \033[1m turns bold on, \033[{ccode}m selects the color and \033[0m
    # resets every attribute so later output is unaffected.
    print(f"\033[1m\033[{ccode}m{s}\033[0m", end="", flush=True)
|
|
||||||
|
|
||||||
|
|
||||||
def colorprint_by_t0(s, t0):
    """
    Print ``s`` in a color chosen from the value ``t0``.

    The thresholds are strict (``>``), so a value that sits exactly on a
    boundary falls into the lower band:

    * ``t0 > 0.95``  -> ``"white"``
    * ``t0 > 0.70``  -> ``"green"``
    * ``t0 > 0.30``  -> ``"yellow"``
    * anything else  -> ``"red"``

    Args:
        s (str): The text to print.
        t0 (float): The value used to pick the color. The ``--colorize``
          help text describes it as the "T[0] probability".
    """
    if t0 > 0.95:
        color = "white"
    elif t0 > 0.70:
        color = "green"
    elif t0 > 0.30:
        color = "yellow"
    else:
        # Also reached when every comparison above is false (e.g. NaN).
        color = "red"
    colorprint(color, s)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = setup_arg_parser()
|
parser = setup_arg_parser()
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@ -250,17 +218,12 @@ def main():
|
|||||||
else:
|
else:
|
||||||
prompt = args.prompt
|
prompt = args.prompt
|
||||||
|
|
||||||
if args.colorize and not args.verbose:
|
|
||||||
raise ValueError("Cannot use --colorize with --verbose=False")
|
|
||||||
formatter = colorprint_by_t0 if args.colorize else None
|
|
||||||
|
|
||||||
response = generate(
|
response = generate(
|
||||||
model,
|
model,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
prompt,
|
prompt,
|
||||||
args.max_tokens,
|
max_tokens=args.max_tokens,
|
||||||
verbose=args.verbose,
|
verbose=args.verbose,
|
||||||
formatter=formatter,
|
|
||||||
temp=args.temp,
|
temp=args.temp,
|
||||||
top_p=args.top_p,
|
top_p=args.top_p,
|
||||||
min_p=args.min_p,
|
min_p=args.min_p,
|
||||||
|
@ -464,8 +464,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
|
|
||||||
text = ""
|
text = ""
|
||||||
tic = time.perf_counter()
|
tic = time.perf_counter()
|
||||||
for n, (segment, token, logprobs) in enumerate(
|
for gen_response in stream_generate(
|
||||||
stream_generate(
|
|
||||||
model=self.model,
|
model=self.model,
|
||||||
tokenizer=self.tokenizer,
|
tokenizer=self.tokenizer,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
@ -475,14 +474,11 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
repetition_context_size=self.repetition_context_size,
|
repetition_context_size=self.repetition_context_size,
|
||||||
logit_bias=self.logit_bias,
|
logit_bias=self.logit_bias,
|
||||||
prompt_cache=self.prompt_cache.cache,
|
prompt_cache=self.prompt_cache.cache,
|
||||||
),
|
|
||||||
):
|
):
|
||||||
if n == 0:
|
text += gen_response.text
|
||||||
prompt_time = time.perf_counter() - tic
|
|
||||||
tic = time.perf_counter()
|
|
||||||
|
|
||||||
text += segment
|
|
||||||
logging.debug(text)
|
logging.debug(text)
|
||||||
|
token = gen_response.token
|
||||||
|
logprobs = gen_response.logprobs
|
||||||
tokens.append(token)
|
tokens.append(token)
|
||||||
|
|
||||||
if self.logprobs > 0:
|
if self.logprobs > 0:
|
||||||
@ -523,13 +519,9 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
|
|
||||||
self.prompt_cache.tokens.extend(tokens)
|
self.prompt_cache.tokens.extend(tokens)
|
||||||
|
|
||||||
gen_time = time.perf_counter() - tic
|
logging.debug(f"Prompt: {gen_response.prompt_tps:.3f} tokens-per-sec")
|
||||||
prompt_tps = len(prompt) / prompt_time
|
logging.debug(f"Generation: {gen_response.generation_tps:.3f} tokens-per-sec")
|
||||||
gen_tps = len(tokens) / gen_time
|
logging.debug(f"Peak memory: {gen_response.peak_memory:.3f} GB")
|
||||||
peak_mem = mx.metal.get_peak_memory() / 1e9
|
|
||||||
logging.debug(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
|
|
||||||
logging.debug(f"Generation: {gen_tps:.3f} tokens-per-sec")
|
|
||||||
logging.debug(f"Peak memory: {peak_mem:.3f} GB")
|
|
||||||
|
|
||||||
if self.stream:
|
if self.stream:
|
||||||
response = self.generate_response(segment, finish_reason)
|
response = self.generate_response(segment, finish_reason)
|
||||||
|
@ -8,6 +8,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import shutil
|
import shutil
|
||||||
import time
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from textwrap import dedent
|
from textwrap import dedent
|
||||||
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union
|
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union
|
||||||
@ -44,6 +45,32 @@ class ModelNotFoundError(Exception):
|
|||||||
super().__init__(self.message)
|
super().__init__(self.message)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class GenerationResponse:
    """
    The output of :func:`stream_generate`.

    One instance is yielded per generation step, so the counters and rates
    below describe the run up to and including that step.

    Args:
        text (str): The next segment of decoded text. This can be an empty string.
        token (int): The next token.
        logprobs (mx.array): A vector of log probabilities.
        prompt_tokens (int): The number of tokens in the prompt.
        prompt_tps (float): The prompt processing tokens-per-second.
        generation_tokens (int): The number of generated tokens.
        generation_tps (float): The tokens-per-second for generation.
        peak_memory (float): The peak memory used so far in GB.
    """

    text: str
    token: int
    logprobs: mx.array
    prompt_tokens: int
    prompt_tps: float
    generation_tokens: int
    generation_tps: float
    # Filled by stream_generate from mx.metal.get_peak_memory() / 1e9.
    peak_memory: float
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def wired_limit(model: nn.Module, streams: Optional[List[mx.Stream]] = None):
|
def wired_limit(model: nn.Module, streams: Optional[List[mx.Stream]] = None):
|
||||||
"""
|
"""
|
||||||
@ -290,17 +317,20 @@ def stream_generate(
|
|||||||
if not isinstance(tokenizer, TokenizerWrapper):
|
if not isinstance(tokenizer, TokenizerWrapper):
|
||||||
tokenizer = TokenizerWrapper(tokenizer)
|
tokenizer = TokenizerWrapper(tokenizer)
|
||||||
|
|
||||||
prompt_tokens = mx.array(
|
prompt = mx.array(prompt if isinstance(prompt, list) else tokenizer.encode(prompt))
|
||||||
prompt if isinstance(prompt, list) else tokenizer.encode(prompt)
|
|
||||||
)
|
|
||||||
detokenizer = tokenizer.detokenizer
|
detokenizer = tokenizer.detokenizer
|
||||||
|
|
||||||
with wired_limit(model, [generation_stream]):
|
with wired_limit(model, [generation_stream]):
|
||||||
detokenizer.reset()
|
detokenizer.reset()
|
||||||
for n, (token, logits) in zip(
|
tic = time.perf_counter()
|
||||||
|
for n, (token, logprobs) in zip(
|
||||||
range(max_tokens),
|
range(max_tokens),
|
||||||
generate_step(prompt_tokens, model, **kwargs),
|
generate_step(prompt, model, **kwargs),
|
||||||
):
|
):
|
||||||
|
if n == 0:
|
||||||
|
prompt_time = time.perf_counter() - tic
|
||||||
|
prompt_tps = prompt.size / prompt_time
|
||||||
|
tic = time.perf_counter()
|
||||||
if token == tokenizer.eos_token_id:
|
if token == tokenizer.eos_token_id:
|
||||||
break
|
break
|
||||||
|
|
||||||
@ -309,17 +339,34 @@ def stream_generate(
|
|||||||
if n == (max_tokens - 1):
|
if n == (max_tokens - 1):
|
||||||
break
|
break
|
||||||
|
|
||||||
yield detokenizer.last_segment, token, logits
|
yield GenerationResponse(
|
||||||
|
text=detokenizer.last_segment,
|
||||||
|
token=token,
|
||||||
|
logprobs=logprobs,
|
||||||
|
prompt_tokens=prompt.size,
|
||||||
|
prompt_tps=prompt_tps,
|
||||||
|
generation_tokens=n + 1,
|
||||||
|
generation_tps=(n + 1) / (time.perf_counter() - tic),
|
||||||
|
peak_memory=mx.metal.get_peak_memory() / 1e9,
|
||||||
|
)
|
||||||
|
|
||||||
detokenizer.finalize()
|
detokenizer.finalize()
|
||||||
yield detokenizer.last_segment, token, logits
|
yield GenerationResponse(
|
||||||
|
text=detokenizer.last_segment,
|
||||||
|
token=token,
|
||||||
|
logprobs=logprobs,
|
||||||
|
prompt_tokens=prompt.size,
|
||||||
|
prompt_tps=prompt_tps,
|
||||||
|
generation_tokens=n + 1,
|
||||||
|
generation_tps=(n + 1) / (time.perf_counter() - tic),
|
||||||
|
peak_memory=mx.metal.get_peak_memory() / 1e9,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
|
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
|
||||||
prompt: str,
|
prompt: str,
|
||||||
max_tokens: int = 100,
|
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
formatter: Optional[Callable] = None,
|
formatter: Optional[Callable] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@ -331,67 +378,41 @@ def generate(
|
|||||||
model (nn.Module): The language model.
|
model (nn.Module): The language model.
|
||||||
tokenizer (PreTrainedTokenizer): The tokenizer.
|
tokenizer (PreTrainedTokenizer): The tokenizer.
|
||||||
prompt (str): The string prompt.
|
prompt (str): The string prompt.
|
||||||
max_tokens (int): The maximum number of tokens. Default: ``100``.
|
|
||||||
verbose (bool): If ``True``, print tokens and timing information.
|
verbose (bool): If ``True``, print tokens and timing information.
|
||||||
Default: ``False``.
|
Default: ``False``.
|
||||||
formatter (Optional[Callable]): A function which takes a token and a
|
kwargs: The remaining options get passed to :func:`stream_generate`.
|
||||||
probability and displays it.
|
See :func:`stream_generate` for more details.
|
||||||
kwargs: The remaining options get passed to :func:`generate_step`.
|
|
||||||
See :func:`generate_step` for more details.
|
|
||||||
"""
|
"""
|
||||||
if not isinstance(tokenizer, TokenizerWrapper):
|
if formatter is not None:
|
||||||
tokenizer = TokenizerWrapper(tokenizer)
|
print(
|
||||||
|
"Text formatting has been deprecated and will be removed in the next version."
|
||||||
|
)
|
||||||
if verbose:
|
if verbose:
|
||||||
print("=" * 10)
|
print("=" * 10)
|
||||||
print("Prompt:", prompt)
|
print("Prompt:", prompt)
|
||||||
|
|
||||||
prompt_tokens = mx.array(tokenizer.encode(prompt))
|
full_text = ""
|
||||||
detokenizer = tokenizer.detokenizer
|
for response in stream_generate(model, tokenizer, prompt, **kwargs):
|
||||||
|
if verbose:
|
||||||
with wired_limit(model, [generation_stream]):
|
print(response.text, end="", flush=True)
|
||||||
tic = time.perf_counter()
|
full_text += response.text
|
||||||
detokenizer.reset()
|
|
||||||
for n, (token, logprobs) in zip(
|
|
||||||
range(max_tokens),
|
|
||||||
generate_step(prompt_tokens, model, **kwargs),
|
|
||||||
):
|
|
||||||
if n == 0:
|
|
||||||
prompt_time = time.perf_counter() - tic
|
|
||||||
tic = time.perf_counter()
|
|
||||||
if token == tokenizer.eos_token_id:
|
|
||||||
break
|
|
||||||
detokenizer.add_token(token)
|
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
if formatter:
|
print()
|
||||||
# We have to finalize so that the prob corresponds to the last segment
|
|
||||||
detokenizer.finalize()
|
|
||||||
prob = mx.exp(logprobs[token]).item()
|
|
||||||
formatter(detokenizer.last_segment, prob)
|
|
||||||
else:
|
|
||||||
print(detokenizer.last_segment, end="", flush=True)
|
|
||||||
|
|
||||||
token_count = n + 1
|
|
||||||
detokenizer.finalize()
|
|
||||||
|
|
||||||
if verbose:
|
|
||||||
gen_time = time.perf_counter() - tic
|
|
||||||
print(detokenizer.last_segment, flush=True)
|
|
||||||
print("=" * 10)
|
print("=" * 10)
|
||||||
if token_count == 0:
|
if len(full_text) == 0:
|
||||||
print("No tokens generated for this prompt")
|
print("No text generated for this prompt")
|
||||||
return
|
return
|
||||||
prompt_tps = prompt_tokens.size / prompt_time
|
|
||||||
gen_tps = (token_count - 1) / gen_time
|
|
||||||
print(
|
print(
|
||||||
f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec"
|
f"Prompt: {response.prompt_tokens} tokens, "
|
||||||
|
f"{response.prompt_tps:.3f} tokens-per-sec"
|
||||||
)
|
)
|
||||||
print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
|
print(
|
||||||
peak_mem = mx.metal.get_peak_memory() / 1e9
|
f"Generation: {response.generation_tokens} tokens, "
|
||||||
print(f"Peak memory: {peak_mem:.3f} GB")
|
f"{response.generation_tps:.3f} tokens-per-sec"
|
||||||
|
)
|
||||||
return detokenizer.text
|
print(f"Peak memory: {response.peak_memory:.3f} GB")
|
||||||
|
return full_text
|
||||||
|
|
||||||
|
|
||||||
def load_config(model_path: Path) -> dict:
|
def load_config(model_path: Path) -> dict:
|
||||||
|
Loading…
Reference in New Issue
Block a user