This commit is contained in:
Awni Hannun 2024-11-07 19:43:42 -08:00
parent 431988721f
commit 566af15c34
4 changed files with 13 additions and 13 deletions

View File

@ -61,7 +61,7 @@ prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True messages, tokenize=False, add_generation_prompt=True
) )
response = generate(model, tokenizer, prompt=prompt, verbose=True) text = generate(model, tokenizer, prompt=prompt, verbose=True)
``` ```
To see a description of all the arguments you can do: To see a description of all the arguments you can do:
@ -117,8 +117,8 @@ prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True messages, tokenize=False, add_generation_prompt=True
) )
for text, *_ in stream_generate(model, tokenizer, prompt, max_tokens=512): for response in stream_generate(model, tokenizer, prompt, max_tokens=512):
print(t, end="", flush=True) print(response.text, end="", flush=True)
print() print()
``` ```

View File

@ -74,7 +74,7 @@ def main():
prompt = tokenizer.apply_chat_template( prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True messages, tokenize=False, add_generation_prompt=True
) )
for response, *_ in stream_generate( for response in stream_generate(
model, model,
tokenizer, tokenizer,
prompt, prompt,
@ -83,7 +83,7 @@ def main():
top_p=args.top_p, top_p=args.top_p,
prompt_cache=prompt_cache, prompt_cache=prompt_cache,
): ):
print(response, flush=True, end="") print(response.text, flush=True, end="")
print() print()

View File

@ -475,7 +475,8 @@ class APIHandler(BaseHTTPRequestHandler):
logit_bias=self.logit_bias, logit_bias=self.logit_bias,
prompt_cache=self.prompt_cache.cache, prompt_cache=self.prompt_cache.cache,
): ):
text += gen_response.text segment = gen_response.text
text += segment
logging.debug(text) logging.debug(text)
token = gen_response.token token = gen_response.token
logprobs = gen_response.logprobs logprobs = gen_response.logprobs

View File

@ -378,29 +378,28 @@ def generate(
model (nn.Module): The language model. model (nn.Module): The language model.
tokenizer (PreTrainedTokenizer): The tokenizer. tokenizer (PreTrainedTokenizer): The tokenizer.
prompt (str): The string prompt. prompt (str): The string prompt.
max_tokens (int): The maximum number of tokens. Default: ``100``.
verbose (bool): If ``True``, print tokens and timing information. verbose (bool): If ``True``, print tokens and timing information.
Default: ``False``. Default: ``False``.
kwargs: The remaining options get passed to :func:`stream_generate`. kwargs: The remaining options get passed to :func:`stream_generate`.
See :func:`stream_generate` for more details. See :func:`stream_generate` for more details.
""" """
if formatter is not None: if formatter is not None:
print( print("Text formatting is deprecated and will be removed in the next version.")
"Text formatting has been deprecated and will be removed in the next version."
)
if verbose: if verbose:
print("=" * 10) print("=" * 10)
print("Prompt:", prompt) print("Prompt:", prompt)
full_text = "" text = ""
for response in stream_generate(model, tokenizer, prompt, **kwargs): for response in stream_generate(model, tokenizer, prompt, **kwargs):
if verbose: if verbose:
print(response.text, end="", flush=True) print(response.text, end="", flush=True)
full_text += response.text text += response.text
if verbose: if verbose:
print() print()
print("=" * 10) print("=" * 10)
if len(full_text) == 0: if len(text) == 0:
print("No text generated for this prompt") print("No text generated for this prompt")
return return
print( print(
@ -412,7 +411,7 @@ def generate(
f"{response.generation_tps:.3f} tokens-per-sec" f"{response.generation_tps:.3f} tokens-per-sec"
) )
print(f"Peak memory: {response.peak_memory:.3f} GB") print(f"Peak memory: {response.peak_memory:.3f} GB")
return full_text return text
def load_config(model_path: Path) -> dict: def load_config(model_path: Path) -> dict: