This commit is contained in:
Awni Hannun 2024-11-07 19:43:42 -08:00
parent 431988721f
commit 566af15c34
4 changed files with 13 additions and 13 deletions

View File

@@ -61,7 +61,7 @@ prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
response = generate(model, tokenizer, prompt=prompt, verbose=True)
text = generate(model, tokenizer, prompt=prompt, verbose=True)
```
To see a description of all the arguments you can do:
@@ -117,8 +117,8 @@ prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
for text, *_ in stream_generate(model, tokenizer, prompt, max_tokens=512):
print(t, end="", flush=True)
for response in stream_generate(model, tokenizer, prompt, max_tokens=512):
print(response.text, end="", flush=True)
print()
```

View File

@@ -74,7 +74,7 @@ def main():
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
for response, *_ in stream_generate(
for response in stream_generate(
model,
tokenizer,
prompt,
@@ -83,7 +83,7 @@ def main():
top_p=args.top_p,
prompt_cache=prompt_cache,
):
print(response, flush=True, end="")
print(response.text, flush=True, end="")
print()

View File

@@ -475,7 +475,8 @@ class APIHandler(BaseHTTPRequestHandler):
logit_bias=self.logit_bias,
prompt_cache=self.prompt_cache.cache,
):
text += gen_response.text
segment = gen_response.text
text += segment
logging.debug(text)
token = gen_response.token
logprobs = gen_response.logprobs

View File

@@ -378,29 +378,28 @@ def generate(
model (nn.Module): The language model.
tokenizer (PreTrainedTokenizer): The tokenizer.
prompt (str): The string prompt.
max_tokens (int): The maximum number of tokens. Default: ``100``.
verbose (bool): If ``True``, print tokens and timing information.
Default: ``False``.
kwargs: The remaining options get passed to :func:`stream_generate`.
See :func:`stream_generate` for more details.
"""
if formatter is not None:
print(
"Text formatting has been deprecated and will be removed in the next version."
)
print("Text formatting is deprecated and will be removed in the next version.")
if verbose:
print("=" * 10)
print("Prompt:", prompt)
full_text = ""
text = ""
for response in stream_generate(model, tokenizer, prompt, **kwargs):
if verbose:
print(response.text, end="", flush=True)
full_text += response.text
text += response.text
if verbose:
print()
print("=" * 10)
if len(full_text) == 0:
if len(text) == 0:
print("No text generated for this prompt")
return
print(
@@ -412,7 +411,7 @@ def generate(
f"{response.generation_tps:.3f} tokens-per-sec"
)
print(f"Peak memory: {response.peak_memory:.3f} GB")
return full_text
return text
def load_config(model_path: Path) -> dict: