mlx_lm: Add Streaming Capability to Generate Function (#807)

* Add streaming capability to the text generation function

* Separate the streaming and regular generation functions

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Author: Michał Kurc
Date: 2024-06-03 18:04:39 +02:00 (committed by GitHub)
Commit: 43d6deb3c1 (parent 8353bbbf93)
4 changed files with 89 additions and 37 deletions

README.md

@@ -27,7 +27,7 @@ You can use `mlx-lm` as a module:
 ```python
 from mlx_lm import load, generate

-model, tokenizer = load("mistralai/Mistral-7B-Instruct-v0.1")
+model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

 response = generate(model, tokenizer, prompt="hello", verbose=True)
 ```

@@ -46,13 +46,14 @@ You can convert models in the Python API with:
 ```python
 from mlx_lm import convert

-upload_repo = "mlx-community/My-Mistral-7B-v0.1-4bit"
+repo = "mistralai/Mistral-7B-Instruct-v0.3"
+upload_repo = "mlx-community/My-Mistral-7B-Instruct-v0.3-4bit"

-convert("mistralai/Mistral-7B-v0.1", quantize=True, upload_repo=upload_repo)
+convert(repo, quantize=True, upload_repo=upload_repo)
 ```

-This will generate a 4-bit quantized Mistral-7B and upload it to the
-repo `mlx-community/My-Mistral-7B-v0.1-4bit`. It will also save the
+This will generate a 4-bit quantized Mistral 7B and upload it to the repo
+`mlx-community/My-Mistral-7B-Instruct-v0.3-4bit`. It will also save the
 converted model in the path `mlx_model` by default.

 To see a description of all the arguments you can do:

@@ -61,12 +62,30 @@ To see a description of all the arguments you can do:
 >>> help(convert)
 ```

+#### Streaming
+
+For streaming generation, use the `stream_generate` function. This returns a
+generator object which streams the output text. For example,
+
+```python
+from mlx_lm import load, stream_generate
+
+repo = "mlx-community/Mistral-7B-Instruct-v0.3-4bit"
+model, tokenizer = load(repo)
+
+prompt = "Write a story about Einstein"
+
+for t in stream_generate(model, tokenizer, prompt, max_tokens=512):
+    print(t, end="", flush=True)
+print()
+```
+
 ### Command Line

 You can also use `mlx-lm` from the command line with:

 ```
-mlx_lm.generate --model mistralai/Mistral-7B-Instruct-v0.1 --prompt "hello"
+mlx_lm.generate --model mistralai/Mistral-7B-Instruct-v0.3 --prompt "hello"
 ```

 This will download a Mistral 7B model from the Hugging Face Hub and generate

@@ -81,7 +100,7 @@ mlx_lm.generate --help
 To quantize a model from the command line run:

 ```
-mlx_lm.convert --hf-path mistralai/Mistral-7B-Instruct-v0.1 -q
+mlx_lm.convert --hf-path mistralai/Mistral-7B-Instruct-v0.3 -q
 ```

 For more options run:

@@ -96,7 +115,7 @@ You can upload new models to Hugging Face by specifying `--upload-repo` to

 ```
 mlx_lm.convert \
-    --hf-path mistralai/Mistral-7B-v0.1 \
+    --hf-path mistralai/Mistral-7B-Instruct-v0.3 \
     -q \
     --upload-repo mlx-community/my-4bit-mistral
 ```

mlx_lm/__init__.py

@@ -1,4 +1,4 @@
 # Copyright © 2023-2024 Apple Inc.

-from .utils import convert, generate, load
+from .utils import convert, generate, load, stream_generate
 from .version import __version__
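A quick check of the new export (a minimal sketch, assuming the package is installed as `mlx-lm`): with the `__init__.py` change above, `stream_generate` can be imported from the package root alongside `load` and `generate`.

```python
# stream_generate is now re-exported from the package root, so it can be
# imported alongside load and generate.
from mlx_lm import load, generate, stream_generate

print(callable(stream_generate))  # True once the export is wired up
```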

mlx_lm/generate.py

@@ -149,10 +149,10 @@ def main():
         model,
         tokenizer,
         prompt,
-        args.temp,
         args.max_tokens,
-        True,
+        verbose=True,
         formatter=formatter,
+        temp=args.temp,
         top_p=args.top_p,
     )
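For reference, a minimal standalone sketch of the updated Python call (the model repo and option values here are illustrative, not part of the commit): sampling options such as `temp` and `top_p` are now passed by keyword, and `generate` forwards them to `generate_step`.

```python
from mlx_lm import load, generate

# Illustrative model repo; any model supported by mlx_lm.load works similarly.
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

# After this commit, temp is no longer a positional parameter of generate();
# it is forwarded to generate_step() together with top_p via **kwargs.
response = generate(
    model,
    tokenizer,
    prompt="hello",
    max_tokens=100,
    verbose=True,
    temp=0.0,    # sampling temperature (argmax when 0)
    top_p=1.0,   # nucleus sampling threshold
)
```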

mlx_lm/utils.py

@@ -136,15 +136,19 @@ def generate_step(
     logit_bias: Optional[Dict[int, float]] = None,
 ) -> Generator[Tuple[mx.array, mx.array], None, None]:
     """
-    A generator producing text based on the given prompt from the model.
+    A generator producing token ids based on the given prompt from the model.

     Args:
         prompt (mx.array): The input prompt.
         model (nn.Module): The model to use for generation.
         temp (float): The temperature for sampling, if 0 the argmax is used.
-        repetition_penalty (float, optional): The penalty factor for repeating tokens.
-        repetition_context_size (int, optional): The number of tokens to consider for repetition penalty (default 20).
-        top_p (float, optional): Nulceus sampling, higher means model considers more less likely words
+          Default: ``0``.
+        repetition_penalty (float, optional): The penalty factor for repeating
+          tokens.
+        repetition_context_size (int, optional): The number of tokens to
+          consider for repetition penalty. Default: ``20``.
+        top_p (float, optional): Nulceus sampling, higher means model considers
+          more less likely words.

     Yields:
         Generator[Tuple[mx.array, mx.array]]: A generator producing
@@ -218,34 +222,71 @@ def generate_step(
         y, p = next_y, next_p


+def stream_generate(
+    model: nn.Module,
+    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
+    prompt: str,
+    max_tokens: int = 100,
+    **kwargs,
+) -> Union[str, Generator[str, None, None]]:
+    """
+    A generator producing text based on the given prompt from the model.
+
+    Args:
+        prompt (mx.array): The input prompt.
+        model (nn.Module): The model to use for generation.
+        max_tokens (int): The maximum number of tokens. Default: ``100``.
+        kwargs: The remaining options get passed to :func:`generate_step`.
+          See :func:`generate_step` for more details.
+
+    Yields:
+        Generator[Tuple[mx.array, mx.array]]: A generator producing text.
+    """
+    if not isinstance(tokenizer, TokenizerWrapper):
+        tokenizer = TokenizerWrapper(tokenizer)
+
+    prompt_tokens = mx.array(tokenizer.encode(prompt))
+    detokenizer = tokenizer.detokenizer
+
+    detokenizer.reset()
+    for (token, prob), n in zip(
+        generate_step(prompt_tokens, model, **kwargs),
+        range(max_tokens),
+    ):
+        if token == tokenizer.eos_token_id:
+            break
+        detokenizer.add_token(token)
+
+        # Yield the last segment if streaming
+        yield detokenizer.last_segment
+
+    detokenizer.finalize()
+    yield detokenizer.last_segment
+
+
 def generate(
     model: nn.Module,
     tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
     prompt: str,
-    temp: float = 0.0,
     max_tokens: int = 100,
     verbose: bool = False,
     formatter: Optional[Callable] = None,
-    repetition_penalty: Optional[float] = None,
-    repetition_context_size: Optional[int] = None,
-    top_p: float = 1.0,
-    logit_bias: Optional[Dict[int, float]] = None,
-) -> str:
+    **kwargs,
+) -> Union[str, Generator[str, None, None]]:
     """
-    Generate text from the model.
+    Generate a complete response from the model.

     Args:
        model (nn.Module): The language model.
        tokenizer (PreTrainedTokenizer): The tokenizer.
        prompt (str): The string prompt.
-       temp (float): The temperature for sampling (default 0).
-       max_tokens (int): The maximum number of tokens (default 100).
-       verbose (bool): If ``True``, print tokens and timing information
-         (default ``False``).
+       max_tokens (int): The maximum number of tokens. Default: ``100``.
+       verbose (bool): If ``True``, print tokens and timing information.
+         Default: ``False``.
        formatter (Optional[Callable]): A function which takes a token and a
          probability and displays it.
-       repetition_penalty (float, optional): The penalty factor for repeating tokens.
-       repetition_context_size (int, optional): The number of tokens to consider for repetition penalty.
+       kwargs: The remaining options get passed to :func:`generate_step`.
+         See :func:`generate_step` for more details.
     """
     if not isinstance(tokenizer, TokenizerWrapper):
         tokenizer = TokenizerWrapper(tokenizer)
@@ -261,15 +302,7 @@ def generate(
     detokenizer.reset()

     for (token, prob), n in zip(
-        generate_step(
-            prompt_tokens,
-            model,
-            temp,
-            repetition_penalty,
-            repetition_context_size,
-            top_p,
-            logit_bias,
-        ),
+        generate_step(prompt_tokens, model, **kwargs),
         range(max_tokens),
     ):
         if n == 0:
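Taken together, a short usage sketch of the new streaming path (a sketch only; the model repo and option values are illustrative): `stream_generate` yields detokenized text segments as they are produced, and any extra keyword arguments are forwarded to `generate_step`.

```python
from mlx_lm import load, stream_generate

# Illustrative model repo; substitute any model supported by mlx_lm.load.
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

# Extra keyword arguments (temp, top_p, repetition_penalty, ...) are passed
# straight through to generate_step via **kwargs.
for segment in stream_generate(
    model,
    tokenizer,
    prompt="Write a story about Einstein",
    max_tokens=256,
    temp=0.7,
    top_p=0.9,
):
    print(segment, end="", flush=True)
print()
```

Generation stops at `max_tokens` or at the tokenizer's EOS token, whichever comes first.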