mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 10:56:38 +08:00
fixes
This commit is contained in:
parent
431988721f
commit
566af15c34
@ -61,7 +61,7 @@ prompt = tokenizer.apply_chat_template(
|
|||||||
messages, tokenize=False, add_generation_prompt=True
|
messages, tokenize=False, add_generation_prompt=True
|
||||||
)
|
)
|
||||||
|
|
||||||
response = generate(model, tokenizer, prompt=prompt, verbose=True)
|
text = generate(model, tokenizer, prompt=prompt, verbose=True)
|
||||||
```
|
```
|
||||||
|
|
||||||
To see a description of all the arguments you can do:
|
To see a description of all the arguments you can do:
|
||||||
@ -117,8 +117,8 @@ prompt = tokenizer.apply_chat_template(
|
|||||||
messages, tokenize=False, add_generation_prompt=True
|
messages, tokenize=False, add_generation_prompt=True
|
||||||
)
|
)
|
||||||
|
|
||||||
for text, *_ in stream_generate(model, tokenizer, prompt, max_tokens=512):
|
for response in stream_generate(model, tokenizer, prompt, max_tokens=512):
|
||||||
print(t, end="", flush=True)
|
print(response.text, end="", flush=True)
|
||||||
print()
|
print()
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -74,7 +74,7 @@ def main():
|
|||||||
prompt = tokenizer.apply_chat_template(
|
prompt = tokenizer.apply_chat_template(
|
||||||
messages, tokenize=False, add_generation_prompt=True
|
messages, tokenize=False, add_generation_prompt=True
|
||||||
)
|
)
|
||||||
for response, *_ in stream_generate(
|
for response in stream_generate(
|
||||||
model,
|
model,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
prompt,
|
prompt,
|
||||||
@ -83,7 +83,7 @@ def main():
|
|||||||
top_p=args.top_p,
|
top_p=args.top_p,
|
||||||
prompt_cache=prompt_cache,
|
prompt_cache=prompt_cache,
|
||||||
):
|
):
|
||||||
print(response, flush=True, end="")
|
print(response.text, flush=True, end="")
|
||||||
print()
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
@ -475,7 +475,8 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
logit_bias=self.logit_bias,
|
logit_bias=self.logit_bias,
|
||||||
prompt_cache=self.prompt_cache.cache,
|
prompt_cache=self.prompt_cache.cache,
|
||||||
):
|
):
|
||||||
text += gen_response.text
|
segment = gen_response.text
|
||||||
|
text += segment
|
||||||
logging.debug(text)
|
logging.debug(text)
|
||||||
token = gen_response.token
|
token = gen_response.token
|
||||||
logprobs = gen_response.logprobs
|
logprobs = gen_response.logprobs
|
||||||
|
@ -378,29 +378,28 @@ def generate(
|
|||||||
model (nn.Module): The language model.
|
model (nn.Module): The language model.
|
||||||
tokenizer (PreTrainedTokenizer): The tokenizer.
|
tokenizer (PreTrainedTokenizer): The tokenizer.
|
||||||
prompt (str): The string prompt.
|
prompt (str): The string prompt.
|
||||||
|
max_tokens (int): The maximum number of tokens. Default: ``100``.
|
||||||
verbose (bool): If ``True``, print tokens and timing information.
|
verbose (bool): If ``True``, print tokens and timing information.
|
||||||
Default: ``False``.
|
Default: ``False``.
|
||||||
kwargs: The remaining options get passed to :func:`stream_generate`.
|
kwargs: The remaining options get passed to :func:`stream_generate`.
|
||||||
See :func:`stream_generate` for more details.
|
See :func:`stream_generate` for more details.
|
||||||
"""
|
"""
|
||||||
if formatter is not None:
|
if formatter is not None:
|
||||||
print(
|
print("Text formatting is deprecated and will be removed in the next version.")
|
||||||
"Text formatting has been deprecated and will be removed in the next version."
|
|
||||||
)
|
|
||||||
if verbose:
|
if verbose:
|
||||||
print("=" * 10)
|
print("=" * 10)
|
||||||
print("Prompt:", prompt)
|
print("Prompt:", prompt)
|
||||||
|
|
||||||
full_text = ""
|
text = ""
|
||||||
for response in stream_generate(model, tokenizer, prompt, **kwargs):
|
for response in stream_generate(model, tokenizer, prompt, **kwargs):
|
||||||
if verbose:
|
if verbose:
|
||||||
print(response.text, end="", flush=True)
|
print(response.text, end="", flush=True)
|
||||||
full_text += response.text
|
text += response.text
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
print()
|
print()
|
||||||
print("=" * 10)
|
print("=" * 10)
|
||||||
if len(full_text) == 0:
|
if len(text) == 0:
|
||||||
print("No text generated for this prompt")
|
print("No text generated for this prompt")
|
||||||
return
|
return
|
||||||
print(
|
print(
|
||||||
@ -412,7 +411,7 @@ def generate(
|
|||||||
f"{response.generation_tps:.3f} tokens-per-sec"
|
f"{response.generation_tps:.3f} tokens-per-sec"
|
||||||
)
|
)
|
||||||
print(f"Peak memory: {response.peak_memory:.3f} GB")
|
print(f"Peak memory: {response.peak_memory:.3f} GB")
|
||||||
return full_text
|
return text
|
||||||
|
|
||||||
|
|
||||||
def load_config(model_path: Path) -> dict:
|
def load_config(model_path: Path) -> dict:
|
||||||
|
Loading…
Reference in New Issue
Block a user