diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 0e2f7af7..f439ca99 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -298,7 +298,7 @@ def generate_step(
 def stream_generate(
     model: nn.Module,
     tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
-    prompt: Union[str, List[int]],
+    prompt: Union[str, mx.array, List[int]],
     max_tokens: int = 100,
     **kwargs,
 ) -> Generator[GenerationResponse, None, None]:
@@ -308,7 +308,7 @@ def stream_generate(
     Args:
         model (nn.Module): The model to use for generation.
         tokenizer (PreTrainedTokenizer): The tokenizer.
-        prompt (Union[str, List[int]]): The input prompt string or integer tokens.
+        prompt (Union[str, mx.array, List[int]]): The input prompt string or integer tokens.
         max_tokens (int): The maximum number of tokens. Default: ``100``.
         kwargs: The remaining options get passed to :func:`generate_step`.
           See :func:`generate_step` for more details.
@@ -320,7 +320,11 @@ def stream_generate(
     if not isinstance(tokenizer, TokenizerWrapper):
         tokenizer = TokenizerWrapper(tokenizer)
 
-    prompt = mx.array(prompt if isinstance(prompt, list) else tokenizer.encode(prompt))
+    if not isinstance(prompt, mx.array):
+        prompt = mx.array(
+            prompt if isinstance(prompt, list) else tokenizer.encode(prompt)
+        )
+
     detokenizer = tokenizer.detokenizer
 
     with wired_limit(model, [generation_stream]):
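
With this change, a prompt that has already been tokenized into an `mx.array` can be passed to `stream_generate` directly, instead of round-tripping through a Python list. A minimal usage sketch of the new code path (the model name is only an example; any model loadable via `mlx_lm.load` works):

```python
import mlx.core as mx
from mlx_lm import load, stream_generate

# Example model path for illustration only.
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

# Pre-tokenize once; the resulting mx.array is now accepted as-is,
# skipping the str/list handling inside stream_generate.
prompt = mx.array(tokenizer.encode("Write a haiku about the ocean."))

for response in stream_generate(model, tokenizer, prompt, max_tokens=64):
    print(response.text, end="", flush=True)
```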