From 43d6deb3c1d280def1383c84c1e87de9fcf3208b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Micha=C5=82=20Kurc?=
Date: Mon, 3 Jun 2024 18:04:39 +0200
Subject: [PATCH] mlx_lm: Add Streaming Capability to Generate Function (#807)

* Add streaming feature to text generation function

* separate stream and regular functions

---------

Co-authored-by: Awni Hannun
---
 llms/README.md          | 35 +++++++++++++----
 llms/mlx_lm/__init__.py |  2 +-
 llms/mlx_lm/generate.py |  4 +-
 llms/mlx_lm/utils.py    | 85 ++++++++++++++++++++++++++++-------------
 4 files changed, 89 insertions(+), 37 deletions(-)

diff --git a/llms/README.md b/llms/README.md
index 1f14a5e5..4b18ed1f 100644
--- a/llms/README.md
+++ b/llms/README.md
@@ -27,7 +27,7 @@ You can use `mlx-lm` as a module:
 ```python
 from mlx_lm import load, generate
 
-model, tokenizer = load("mistralai/Mistral-7B-Instruct-v0.1")
+model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
 
 response = generate(model, tokenizer, prompt="hello", verbose=True)
 ```
@@ -46,13 +46,14 @@ You can convert models in the Python API with:
 ```python
 from mlx_lm import convert
 
-upload_repo = "mlx-community/My-Mistral-7B-v0.1-4bit"
+repo = "mistralai/Mistral-7B-Instruct-v0.3"
+upload_repo = "mlx-community/My-Mistral-7B-Instruct-v0.3-4bit"
 
-convert("mistralai/Mistral-7B-v0.1", quantize=True, upload_repo=upload_repo)
+convert(repo, quantize=True, upload_repo=upload_repo)
 ```
 
-This will generate a 4-bit quantized Mistral-7B and upload it to the
-repo `mlx-community/My-Mistral-7B-v0.1-4bit`. It will also save the
+This will generate a 4-bit quantized Mistral 7B and upload it to the repo
+`mlx-community/My-Mistral-7B-Instruct-v0.3-4bit`. It will also save the
 converted model in the path `mlx_model` by default.
 
 To see a description of all the arguments you can do:
@@ -61,12 +62,30 @@ To see a description of all the arguments you can do:
 >>> help(convert)
 ```
 
+#### Streaming
+
+For streaming generation, use the `stream_generate` function. This returns a
+generator object which streams the output text. For example,
+
+```python
+from mlx_lm import load, stream_generate
+
+repo = "mlx-community/Mistral-7B-Instruct-v0.3-4bit"
+model, tokenizer = load(repo)
+
+prompt = "Write a story about Einstein"
+
+for t in stream_generate(model, tokenizer, prompt, max_tokens=512):
+    print(t, end="", flush=True)
+print()
+```
+
 ### Command Line
 
 You can also use `mlx-lm` from the command line with:
 
 ```
-mlx_lm.generate --model mistralai/Mistral-7B-Instruct-v0.1 --prompt "hello"
+mlx_lm.generate --model mistralai/Mistral-7B-Instruct-v0.3 --prompt "hello"
 ```
 
 This will download a Mistral 7B model from the Hugging Face Hub and generate
@@ -81,7 +100,7 @@ mlx_lm.generate --help
 To quantize a model from the command line run:
 
 ```
-mlx_lm.convert --hf-path mistralai/Mistral-7B-Instruct-v0.1 -q
+mlx_lm.convert --hf-path mistralai/Mistral-7B-Instruct-v0.3 -q
 ```
 
 For more options run:
@@ -96,7 +115,7 @@ You can upload new models to Hugging Face by specifying `--upload-repo` to
 
 ```
 mlx_lm.convert \
-    --hf-path mistralai/Mistral-7B-v0.1 \
+    --hf-path mistralai/Mistral-7B-Instruct-v0.3 \
     -q \
     --upload-repo mlx-community/my-4bit-mistral
 ```
diff --git a/llms/mlx_lm/__init__.py b/llms/mlx_lm/__init__.py
index ecf69c6d..e971c467 100644
--- a/llms/mlx_lm/__init__.py
+++ b/llms/mlx_lm/__init__.py
@@ -1,4 +1,4 @@
 # Copyright © 2023-2024 Apple Inc.
 
-from .utils import convert, generate, load
+from .utils import convert, generate, load, stream_generate
 from .version import __version__
diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py
index 629bba16..c003940b 100644
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -149,10 +149,10 @@ def main():
         model,
         tokenizer,
         prompt,
-        args.temp,
         args.max_tokens,
-        True,
+        verbose=True,
         formatter=formatter,
+        temp=args.temp,
         top_p=args.top_p,
     )
 
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 82c00fca..d7de95bf 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -136,15 +136,19 @@ def generate_step(
     logit_bias: Optional[Dict[int, float]] = None,
 ) -> Generator[Tuple[mx.array, mx.array], None, None]:
     """
-    A generator producing text based on the given prompt from the model.
+    A generator producing token ids based on the given prompt from the model.
 
     Args:
         prompt (mx.array): The input prompt.
         model (nn.Module): The model to use for generation.
         temp (float): The temperature for sampling, if 0 the argmax is used.
-        repetition_penalty (float, optional): The penalty factor for repeating tokens.
-        repetition_context_size (int, optional): The number of tokens to consider for repetition penalty (default 20).
-        top_p (float, optional): Nulceus sampling, higher means model considers more less likely words
+          Default: ``0``.
+        repetition_penalty (float, optional): The penalty factor for repeating
+          tokens.
+        repetition_context_size (int, optional): The number of tokens to
+          consider for repetition penalty. Default: ``20``.
+        top_p (float, optional): Nucleus sampling, higher means model considers
+          more less likely words.
 
     Yields:
         Generator[Tuple[mx.array, mx.array]]: A generator producing
@@ -218,34 +222,71 @@ def generate_step(
         y, p = next_y, next_p
 
 
+def stream_generate(
+    model: nn.Module,
+    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
+    prompt: str,
+    max_tokens: int = 100,
+    **kwargs,
+) -> Generator[str, None, None]:
+    """
+    A generator producing text based on the given prompt from the model.
+
+    Args:
+        prompt (str): The string prompt.
+        model (nn.Module): The model to use for generation.
+        max_tokens (int): The maximum number of tokens. Default: ``100``.
+        kwargs: The remaining options get passed to :func:`generate_step`.
+          See :func:`generate_step` for more details.
+
+    Yields:
+        Generator[str]: A generator producing text.
+    """
+    if not isinstance(tokenizer, TokenizerWrapper):
+        tokenizer = TokenizerWrapper(tokenizer)
+
+    prompt_tokens = mx.array(tokenizer.encode(prompt))
+    detokenizer = tokenizer.detokenizer
+
+    detokenizer.reset()
+    for (token, prob), n in zip(
+        generate_step(prompt_tokens, model, **kwargs),
+        range(max_tokens),
+    ):
+        if token == tokenizer.eos_token_id:
+            break
+        detokenizer.add_token(token)
+
+        # Yield the next segment of decoded text
+        yield detokenizer.last_segment
+
+    detokenizer.finalize()
+    yield detokenizer.last_segment
+
+
 def generate(
     model: nn.Module,
     tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
     prompt: str,
-    temp: float = 0.0,
     max_tokens: int = 100,
     verbose: bool = False,
     formatter: Optional[Callable] = None,
-    repetition_penalty: Optional[float] = None,
-    repetition_context_size: Optional[int] = None,
-    top_p: float = 1.0,
-    logit_bias: Optional[Dict[int, float]] = None,
-) -> str:
+    **kwargs,
+) -> Union[str, Generator[str, None, None]]:
     """
-    Generate text from the model.
+    Generate a complete response from the model.
 
     Args:
         model (nn.Module): The language model.
         tokenizer (PreTrainedTokenizer): The tokenizer.
         prompt (str): The string prompt.
-        temp (float): The temperature for sampling (default 0).
-        max_tokens (int): The maximum number of tokens (default 100).
-        verbose (bool): If ``True``, print tokens and timing information
-          (default ``False``).
+        max_tokens (int): The maximum number of tokens. Default: ``100``.
+        verbose (bool): If ``True``, print tokens and timing information.
+          Default: ``False``.
         formatter (Optional[Callable]): A function which takes a token and a
           probability and displays it.
-        repetition_penalty (float, optional): The penalty factor for repeating tokens.
-        repetition_context_size (int, optional): The number of tokens to consider for repetition penalty.
+        kwargs: The remaining options get passed to :func:`generate_step`.
+          See :func:`generate_step` for more details.
     """
     if not isinstance(tokenizer, TokenizerWrapper):
         tokenizer = TokenizerWrapper(tokenizer)
@@ -261,15 +302,7 @@ def generate(
 
     detokenizer.reset()
     for (token, prob), n in zip(
-        generate_step(
-            prompt_tokens,
-            model,
-            temp,
-            repetition_penalty,
-            repetition_context_size,
-            top_p,
-            logit_bias,
-        ),
+        generate_step(prompt_tokens, model, **kwargs),
         range(max_tokens),
     ):
         if n == 0:
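
The snippet below is an illustrative usage sketch of the reworked keyword-argument API introduced by this patch: sampling options such as `temp` and `top_p` are no longer positional arguments of `generate` but are forwarded to `generate_step` through `**kwargs`, and the new `stream_generate` accepts the same options. It assumes the `mlx-community/Mistral-7B-Instruct-v0.3-4bit` repo from the README example; the sampling values are arbitrary.

```python
from mlx_lm import load, generate, stream_generate

# Assumes the same 4-bit Mistral repo used in the README example above.
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

# Sampling options (temp, top_p, repetition_penalty, ...) are forwarded to
# generate_step via **kwargs; the values here are arbitrary examples.
text = generate(
    model,
    tokenizer,
    prompt="hello",
    max_tokens=100,
    verbose=True,
    temp=0.7,
    top_p=0.9,
)

# The new stream_generate yields decoded text segments as they are produced.
for segment in stream_generate(model, tokenizer, "hello", max_tokens=100, temp=0.7):
    print(segment, end="", flush=True)
print()
```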