mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
fixes
This commit is contained in:
parent
431988721f
commit
566af15c34
@ -61,7 +61,7 @@ prompt = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
response = generate(model, tokenizer, prompt=prompt, verbose=True)
|
||||
text = generate(model, tokenizer, prompt=prompt, verbose=True)
|
||||
```
|
||||
|
||||
To see a description of all the arguments you can do:
|
||||
@ -117,8 +117,8 @@ prompt = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
for text, *_ in stream_generate(model, tokenizer, prompt, max_tokens=512):
|
||||
print(t, end="", flush=True)
|
||||
for response in stream_generate(model, tokenizer, prompt, max_tokens=512):
|
||||
print(response.text, end="", flush=True)
|
||||
print()
|
||||
```
|
||||
|
||||
|
@ -74,7 +74,7 @@ def main():
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
for response, *_ in stream_generate(
|
||||
for response in stream_generate(
|
||||
model,
|
||||
tokenizer,
|
||||
prompt,
|
||||
@ -83,7 +83,7 @@ def main():
|
||||
top_p=args.top_p,
|
||||
prompt_cache=prompt_cache,
|
||||
):
|
||||
print(response, flush=True, end="")
|
||||
print(response.text, flush=True, end="")
|
||||
print()
|
||||
|
||||
|
||||
|
@ -475,7 +475,8 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
logit_bias=self.logit_bias,
|
||||
prompt_cache=self.prompt_cache.cache,
|
||||
):
|
||||
text += gen_response.text
|
||||
segment = gen_response.text
|
||||
text += segment
|
||||
logging.debug(text)
|
||||
token = gen_response.token
|
||||
logprobs = gen_response.logprobs
|
||||
|
@ -378,29 +378,28 @@ def generate(
|
||||
model (nn.Module): The language model.
|
||||
tokenizer (PreTrainedTokenizer): The tokenizer.
|
||||
prompt (str): The string prompt.
|
||||
max_tokens (int): The maximum number of tokens. Default: ``100``.
|
||||
verbose (bool): If ``True``, print tokens and timing information.
|
||||
Default: ``False``.
|
||||
kwargs: The remaining options get passed to :func:`stream_generate`.
|
||||
See :func:`stream_generate` for more details.
|
||||
"""
|
||||
if formatter is not None:
|
||||
print(
|
||||
"Text formatting has been deprecated and will be removed in the next version."
|
||||
)
|
||||
print("Text formatting is deprecated and will be removed in the next version.")
|
||||
if verbose:
|
||||
print("=" * 10)
|
||||
print("Prompt:", prompt)
|
||||
|
||||
full_text = ""
|
||||
text = ""
|
||||
for response in stream_generate(model, tokenizer, prompt, **kwargs):
|
||||
if verbose:
|
||||
print(response.text, end="", flush=True)
|
||||
full_text += response.text
|
||||
text += response.text
|
||||
|
||||
if verbose:
|
||||
print()
|
||||
print("=" * 10)
|
||||
if len(full_text) == 0:
|
||||
if len(text) == 0:
|
||||
print("No text generated for this prompt")
|
||||
return
|
||||
print(
|
||||
@ -412,7 +411,7 @@ def generate(
|
||||
f"{response.generation_tps:.3f} tokens-per-sec"
|
||||
)
|
||||
print(f"Peak memory: {response.peak_memory:.3f} GB")
|
||||
return full_text
|
||||
return text
|
||||
|
||||
|
||||
def load_config(model_path: Path) -> dict:
|
||||
|
Loading…
Reference in New Issue
Block a user