Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 17:31:18 +08:00)
mlx_lm: Add Streaming Capability to Generate Function (#807)
* Add streaming feature to text generation function
* separate stream and regular functions

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent 8353bbbf93
commit 43d6deb3c1
@@ -27,7 +27,7 @@ You can use `mlx-lm` as a module:
 ```python
 from mlx_lm import load, generate
 
-model, tokenizer = load("mistralai/Mistral-7B-Instruct-v0.1")
+model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
 
 response = generate(model, tokenizer, prompt="hello", verbose=True)
 ```
@@ -46,13 +46,14 @@ You can convert models in the Python API with:
 ```python
 from mlx_lm import convert
 
-upload_repo = "mlx-community/My-Mistral-7B-v0.1-4bit"
+repo = "mistralai/Mistral-7B-Instruct-v0.3"
+upload_repo = "mlx-community/My-Mistral-7B-Instruct-v0.3-4bit"
 
-convert("mistralai/Mistral-7B-v0.1", quantize=True, upload_repo=upload_repo)
+convert(repo, quantize=True, upload_repo=upload_repo)
 ```
 
-This will generate a 4-bit quantized Mistral-7B and upload it to the
-repo `mlx-community/My-Mistral-7B-v0.1-4bit`. It will also save the
+This will generate a 4-bit quantized Mistral 7B and upload it to the repo
+`mlx-community/My-Mistral-7B-Instruct-v0.3-4bit`. It will also save the
 converted model in the path `mlx_model` by default.
 
 To see a description of all the arguments you can do:
@@ -61,12 +62,30 @@ To see a description of all the arguments you can do:
 >>> help(convert)
 ```
 
+#### Streaming
+
+For streaming generation, use the `stream_generate` function. This returns a
+generator object which streams the output text. For example,
+
+```python
+from mlx_lm import load, stream_generate
+
+repo = "mlx-community/Mistral-7B-Instruct-v0.3-4bit"
+model, tokenizer = load(repo)
+
+prompt = "Write a story about Einstein"
+
+for t in stream_generate(model, tokenizer, prompt, max_tokens=512):
+    print(t, end="", flush=True)
+print()
+```
+
 ### Command Line
 
 You can also use `mlx-lm` from the command line with:
 
 ```
-mlx_lm.generate --model mistralai/Mistral-7B-Instruct-v0.1 --prompt "hello"
+mlx_lm.generate --model mistralai/Mistral-7B-Instruct-v0.3 --prompt "hello"
 ```
 
 This will download a Mistral 7B model from the Hugging Face Hub and generate
@@ -81,7 +100,7 @@ mlx_lm.generate --help
 To quantize a model from the command line run:
 
 ```
-mlx_lm.convert --hf-path mistralai/Mistral-7B-Instruct-v0.1 -q
+mlx_lm.convert --hf-path mistralai/Mistral-7B-Instruct-v0.3 -q
 ```
 
 For more options run:
@@ -96,7 +115,7 @@ You can upload new models to Hugging Face by specifying `--upload-repo` to
 
 ```
 mlx_lm.convert \
-    --hf-path mistralai/Mistral-7B-v0.1 \
+    --hf-path mistralai/Mistral-7B-Instruct-v0.3 \
     -q \
     --upload-repo mlx-community/my-4bit-mistral
 ```
@@ -1,4 +1,4 @@
 # Copyright © 2023-2024 Apple Inc.
 
-from .utils import convert, generate, load
+from .utils import convert, generate, load, stream_generate
 from .version import __version__
|
@ -149,10 +149,10 @@ def main():
|
|||||||
model,
|
model,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
prompt,
|
prompt,
|
||||||
args.temp,
|
|
||||||
args.max_tokens,
|
args.max_tokens,
|
||||||
True,
|
verbose=True,
|
||||||
formatter=formatter,
|
formatter=formatter,
|
||||||
|
temp=args.temp,
|
||||||
top_p=args.top_p,
|
top_p=args.top_p,
|
||||||
)
|
)
|
||||||
|
|
||||||
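The updated call site passes `verbose`, `temp`, and `top_p` as keyword arguments, which `generate` now forwards on to `generate_step`. Below is a minimal sketch of the equivalent Python API call; the repo name and sampling values are illustrative, not taken from the commit.

```python
# Sketch of the keyword-argument call style used by the updated CLI.
# Assumes mlx-lm is installed; the repo and sampling values are examples.
from mlx_lm import load, generate

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

# temp and top_p are no longer positional: they travel through **kwargs
# from generate() down to generate_step().
response = generate(
    model,
    tokenizer,
    "hello",
    max_tokens=100,
    verbose=True,
    temp=0.6,
    top_p=0.9,
)
```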
|
@ -136,15 +136,19 @@ def generate_step(
|
|||||||
logit_bias: Optional[Dict[int, float]] = None,
|
logit_bias: Optional[Dict[int, float]] = None,
|
||||||
) -> Generator[Tuple[mx.array, mx.array], None, None]:
|
) -> Generator[Tuple[mx.array, mx.array], None, None]:
|
||||||
"""
|
"""
|
||||||
A generator producing text based on the given prompt from the model.
|
A generator producing token ids based on the given prompt from the model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prompt (mx.array): The input prompt.
|
prompt (mx.array): The input prompt.
|
||||||
model (nn.Module): The model to use for generation.
|
model (nn.Module): The model to use for generation.
|
||||||
temp (float): The temperature for sampling, if 0 the argmax is used.
|
temp (float): The temperature for sampling, if 0 the argmax is used.
|
||||||
repetition_penalty (float, optional): The penalty factor for repeating tokens.
|
Default: ``0``.
|
||||||
repetition_context_size (int, optional): The number of tokens to consider for repetition penalty (default 20).
|
repetition_penalty (float, optional): The penalty factor for repeating
|
||||||
top_p (float, optional): Nulceus sampling, higher means model considers more less likely words
|
tokens.
|
||||||
|
repetition_context_size (int, optional): The number of tokens to
|
||||||
|
consider for repetition penalty. Default: ``20``.
|
||||||
|
top_p (float, optional): Nulceus sampling, higher means model considers
|
||||||
|
more less likely words.
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
Generator[Tuple[mx.array, mx.array]]: A generator producing
|
Generator[Tuple[mx.array, mx.array]]: A generator producing
|
||||||
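Since `generate_step` yields `(token, prob)` pairs rather than text, callers that want raw token ids can consume it directly. The following is a rough sketch, not part of the commit; the repo name, prompt, and token limit are illustrative.

```python
# Hypothetical direct use of generate_step to collect raw token ids.
import mlx.core as mx
from mlx_lm import load
from mlx_lm.utils import generate_step

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
prompt_tokens = mx.array(tokenizer.encode("hello"))

tokens = []
# generate_step is an unbounded generator, so bound it with zip/range.
for (token, prob), _ in zip(generate_step(prompt_tokens, model, temp=0.0), range(64)):
    if token == tokenizer.eos_token_id:
        break
    tokens.append(token.item())

print(tokenizer.decode(tokens))
```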
@@ -218,34 +222,71 @@ def generate_step(
         y, p = next_y, next_p
 
 
+def stream_generate(
+    model: nn.Module,
+    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
+    prompt: str,
+    max_tokens: int = 100,
+    **kwargs,
+) -> Union[str, Generator[str, None, None]]:
+    """
+    A generator producing text based on the given prompt from the model.
+
+    Args:
+        prompt (mx.array): The input prompt.
+        model (nn.Module): The model to use for generation.
+        max_tokens (int): The maximum number of tokens. Default: ``100``.
+        kwargs: The remaining options get passed to :func:`generate_step`.
+          See :func:`generate_step` for more details.
+
+    Yields:
+        Generator[Tuple[mx.array, mx.array]]: A generator producing text.
+    """
+    if not isinstance(tokenizer, TokenizerWrapper):
+        tokenizer = TokenizerWrapper(tokenizer)
+
+    prompt_tokens = mx.array(tokenizer.encode(prompt))
+    detokenizer = tokenizer.detokenizer
+
+    detokenizer.reset()
+    for (token, prob), n in zip(
+        generate_step(prompt_tokens, model, **kwargs),
+        range(max_tokens),
+    ):
+        if token == tokenizer.eos_token_id:
+            break
+        detokenizer.add_token(token)
+
+        # Yield the last segment if streaming
+        yield detokenizer.last_segment
+
+    detokenizer.finalize()
+    yield detokenizer.last_segment
+
+
 def generate(
     model: nn.Module,
     tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
     prompt: str,
-    temp: float = 0.0,
     max_tokens: int = 100,
     verbose: bool = False,
     formatter: Optional[Callable] = None,
-    repetition_penalty: Optional[float] = None,
-    repetition_context_size: Optional[int] = None,
-    top_p: float = 1.0,
-    logit_bias: Optional[Dict[int, float]] = None,
-) -> str:
+    **kwargs,
+) -> Union[str, Generator[str, None, None]]:
     """
-    Generate text from the model.
+    Generate a complete response from the model.
 
     Args:
         model (nn.Module): The language model.
         tokenizer (PreTrainedTokenizer): The tokenizer.
         prompt (str): The string prompt.
-        temp (float): The temperature for sampling (default 0).
-        max_tokens (int): The maximum number of tokens (default 100).
-        verbose (bool): If ``True``, print tokens and timing information
-          (default ``False``).
+        max_tokens (int): The maximum number of tokens. Default: ``100``.
+        verbose (bool): If ``True``, print tokens and timing information.
+          Default: ``False``.
         formatter (Optional[Callable]): A function which takes a token and a
           probability and displays it.
-        repetition_penalty (float, optional): The penalty factor for repeating tokens.
-        repetition_context_size (int, optional): The number of tokens to consider for repetition penalty.
+        kwargs: The remaining options get passed to :func:`generate_step`.
+          See :func:`generate_step` for more details.
     """
     if not isinstance(tokenizer, TokenizerWrapper):
         tokenizer = TokenizerWrapper(tokenizer)
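Because extra keyword arguments now flow from `stream_generate` (and `generate`) into `generate_step`, sampling options such as `temp` and `top_p` can be passed straight through. Here is a small usage sketch; the repo name, prompt, and values are illustrative, not from the commit.

```python
# Sketch: streaming with sampling options forwarded via **kwargs.
# Assumes mlx-lm is installed; repo, prompt, and values are examples.
from mlx_lm import load, stream_generate

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

for chunk in stream_generate(
    model,
    tokenizer,
    "Explain quantization in one short paragraph.",
    max_tokens=256,
    temp=0.7,   # forwarded to generate_step
    top_p=0.9,  # forwarded to generate_step
):
    print(chunk, end="", flush=True)
print()
```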
@@ -261,15 +302,7 @@ def generate(
     detokenizer.reset()
 
     for (token, prob), n in zip(
-        generate_step(
-            prompt_tokens,
-            model,
-            temp,
-            repetition_penalty,
-            repetition_context_size,
-            top_p,
-            logit_bias,
-        ),
+        generate_step(prompt_tokens, model, **kwargs),
         range(max_tokens),
     ):
         if n == 0: