mlx_lm: Add Streaming Capability to Generate Function (#807)

* Add streaming feature to text generation function

* separate stream and regular functions

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Michał Kurc 2024-06-03 18:04:39 +02:00 committed by GitHub
parent 8353bbbf93
commit 43d6deb3c1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 89 additions and 37 deletions

View File

@ -27,7 +27,7 @@ You can use `mlx-lm` as a module:
```python
from mlx_lm import load, generate
model, tokenizer = load("mistralai/Mistral-7B-Instruct-v0.1")
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
response = generate(model, tokenizer, prompt="hello", verbose=True)
```
@ -46,13 +46,14 @@ You can convert models in the Python API with:
```python
from mlx_lm import convert
upload_repo = "mlx-community/My-Mistral-7B-v0.1-4bit"
repo = "mistralai/Mistral-7B-Instruct-v0.3"
upload_repo = "mlx-community/My-Mistral-7B-Instruct-v0.3-4bit"
convert("mistralai/Mistral-7B-v0.1", quantize=True, upload_repo=upload_repo)
convert(repo, quantize=True, upload_repo=upload_repo)
```
This will generate a 4-bit quantized Mistral-7B and upload it to the
repo `mlx-community/My-Mistral-7B-v0.1-4bit`. It will also save the
This will generate a 4-bit quantized Mistral 7B and upload it to the repo
`mlx-community/My-Mistral-7B-Instruct-v0.3-4bit`. It will also save the
converted model in the path `mlx_model` by default.
To see a description of all the arguments you can do:
@ -61,12 +62,30 @@ To see a description of all the arguments you can do:
>>> help(convert)
```
#### Streaming
For streaming generation, use the `stream_generate` function. This returns a
generator object which streams the output text. For example,
```python
from mlx_lm import load, stream_generate
repo = "mlx-community/Mistral-7B-Instruct-v0.3-4bit"
model, tokenizer = load(repo)
prompt = "Write a story about Einstein"
for t in stream_generate(model, tokenizer, prompt, max_tokens=512):
print(t, end="", flush=True)
print()
```
### Command Line
You can also use `mlx-lm` from the command line with:
```
mlx_lm.generate --model mistralai/Mistral-7B-Instruct-v0.1 --prompt "hello"
mlx_lm.generate --model mistralai/Mistral-7B-Instruct-v0.3 --prompt "hello"
```
This will download a Mistral 7B model from the Hugging Face Hub and generate
@ -81,7 +100,7 @@ mlx_lm.generate --help
To quantize a model from the command line run:
```
mlx_lm.convert --hf-path mistralai/Mistral-7B-Instruct-v0.1 -q
mlx_lm.convert --hf-path mistralai/Mistral-7B-Instruct-v0.3 -q
```
For more options run:
@ -96,7 +115,7 @@ You can upload new models to Hugging Face by specifying `--upload-repo` to
```
mlx_lm.convert \
--hf-path mistralai/Mistral-7B-v0.1 \
--hf-path mistralai/Mistral-7B-Instruct-v0.3 \
-q \
--upload-repo mlx-community/my-4bit-mistral
```

View File

@ -1,4 +1,4 @@
# Copyright © 2023-2024 Apple Inc.
from .utils import convert, generate, load
from .utils import convert, generate, load, stream_generate
from .version import __version__

View File

@ -149,10 +149,10 @@ def main():
model,
tokenizer,
prompt,
args.temp,
args.max_tokens,
True,
verbose=True,
formatter=formatter,
temp=args.temp,
top_p=args.top_p,
)

View File

@ -136,15 +136,19 @@ def generate_step(
logit_bias: Optional[Dict[int, float]] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
"""
A generator producing text based on the given prompt from the model.
A generator producing token ids based on the given prompt from the model.
Args:
prompt (mx.array): The input prompt.
model (nn.Module): The model to use for generation.
temp (float): The temperature for sampling, if 0 the argmax is used.
repetition_penalty (float, optional): The penalty factor for repeating tokens.
repetition_context_size (int, optional): The number of tokens to consider for repetition penalty (default 20).
top_p (float, optional): Nulceus sampling, higher means model considers more less likely words
Default: ``0``.
repetition_penalty (float, optional): The penalty factor for repeating
tokens.
repetition_context_size (int, optional): The number of tokens to
consider for repetition penalty. Default: ``20``.
top_p (float, optional): Nulceus sampling, higher means model considers
more less likely words.
Yields:
Generator[Tuple[mx.array, mx.array]]: A generator producing
@ -218,34 +222,71 @@ def generate_step(
y, p = next_y, next_p
def stream_generate(
model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
prompt: str,
max_tokens: int = 100,
**kwargs,
) -> Union[str, Generator[str, None, None]]:
"""
A generator producing text based on the given prompt from the model.
Args:
prompt (mx.array): The input prompt.
model (nn.Module): The model to use for generation.
max_tokens (int): The ma
kwargs: The remaining options get passed to :func:`generate_step`.
See :func:`generate_step` for more details.
Yields:
Generator[Tuple[mx.array, mx.array]]: A generator producing text.
"""
if not isinstance(tokenizer, TokenizerWrapper):
tokenizer = TokenizerWrapper(tokenizer)
prompt_tokens = mx.array(tokenizer.encode(prompt))
detokenizer = tokenizer.detokenizer
detokenizer.reset()
for (token, prob), n in zip(
generate_step(prompt_tokens, model, **kwargs),
range(max_tokens),
):
if token == tokenizer.eos_token_id:
break
detokenizer.add_token(token)
# Yield the last segment if streaming
yield detokenizer.last_segment
detokenizer.finalize()
yield detokenizer.last_segment
def generate(
model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
prompt: str,
temp: float = 0.0,
max_tokens: int = 100,
verbose: bool = False,
formatter: Optional[Callable] = None,
repetition_penalty: Optional[float] = None,
repetition_context_size: Optional[int] = None,
top_p: float = 1.0,
logit_bias: Optional[Dict[int, float]] = None,
) -> str:
**kwargs,
) -> Union[str, Generator[str, None, None]]:
"""
Generate text from the model.
Generate a complete response from the model.
Args:
model (nn.Module): The language model.
tokenizer (PreTrainedTokenizer): The tokenizer.
prompt (str): The string prompt.
temp (float): The temperature for sampling (default 0).
max_tokens (int): The maximum number of tokens (default 100).
verbose (bool): If ``True``, print tokens and timing information
(default ``False``).
max_tokens (int): The maximum number of tokens. Default: ``100``.
verbose (bool): If ``True``, print tokens and timing information.
Default: ``False``.
formatter (Optional[Callable]): A function which takes a token and a
probability and displays it.
repetition_penalty (float, optional): The penalty factor for repeating tokens.
repetition_context_size (int, optional): The number of tokens to consider for repetition penalty.
kwargs: The remaining options get passed to :func:`generate_step`.
See :func:`generate_step` for more details.
"""
if not isinstance(tokenizer, TokenizerWrapper):
tokenizer = TokenizerWrapper(tokenizer)
@ -261,15 +302,7 @@ def generate(
detokenizer.reset()
for (token, prob), n in zip(
generate_step(
prompt_tokens,
model,
temp,
repetition_penalty,
repetition_context_size,
top_p,
logit_bias,
),
generate_step(prompt_tokens, model, **kwargs),
range(max_tokens),
):
if n == 0: