Accept mx.array type for the prompt argument of stream_generate

Neil Mehta 2024-11-26 13:53:03 -05:00
parent cfc29c29f4
commit 152d6b1e1e


@@ -298,7 +298,7 @@ def generate_step(
 def stream_generate(
     model: nn.Module,
     tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
-    prompt: Union[str, List[int]],
+    prompt: Union[str, mx.array, List[int]],
     max_tokens: int = 100,
     **kwargs,
 ) -> Generator[GenerationResponse, None, None]:
@@ -308,7 +308,7 @@ def stream_generate(
     Args:
         model (nn.Module): The model to use for generation.
         tokenizer (PreTrainedTokenizer): The tokenizer.
-        prompt (Union[str, List[int]]): The input prompt string or integer tokens.
+        prompt (Union[str, mx.array, List[int]]): The input prompt string or integer tokens.
         max_tokens (int): The maximum number of tokens. Default: ``100``.
         kwargs: The remaining options get passed to :func:`generate_step`.
           See :func:`generate_step` for more details.
@@ -320,7 +320,9 @@ def stream_generate(
     if not isinstance(tokenizer, TokenizerWrapper):
         tokenizer = TokenizerWrapper(tokenizer)
-    prompt = mx.array(prompt if isinstance(prompt, list) else tokenizer.encode(prompt))
+    if not isinstance(prompt, mx.array):
+        prompt = mx.array(prompt if isinstance(prompt, list) else tokenizer.encode(prompt))
     detokenizer = tokenizer.detokenizer
     with wired_limit(model, [generation_stream]):
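
A minimal usage sketch of the change (not part of the commit): after this patch, a prompt that is already tokenized and held as an mx.array is passed through without re-encoding. The model identifier below is a placeholder for any model loadable with mlx_lm.load, and the sketch assumes the GenerationResponse yielded by stream_generate exposes the generated chunk as .text, as in mlx_lm at the time of this commit.

    import mlx.core as mx
    from mlx_lm import load, stream_generate

    # Placeholder model id; substitute any model supported by mlx_lm.load.
    model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

    # Pre-tokenize once and keep the tokens as an mx.array.
    prompt_tokens = mx.array(tokenizer.encode("Write a haiku about streams."))

    # With this commit the mx.array is accepted directly; previously only
    # str or List[int] prompts were handled by stream_generate.
    for response in stream_generate(model, tokenizer, prompt_tokens, max_tokens=50):
        print(response.text, end="", flush=True)
    print()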