Accept mx.array type for prompt argument for stream_generate (#1125)

* Accept mx.array type for prompt argument for stream_generate

* Fix formatting
This commit is contained in:
Neil Mehta 2024-11-26 19:51:55 -05:00 committed by GitHub
parent cfc29c29f4
commit cefe793ae0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -298,7 +298,7 @@ def generate_step(
def stream_generate(
model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
prompt: Union[str, List[int]],
prompt: Union[str, mx.array, List[int]],
max_tokens: int = 100,
**kwargs,
) -> Generator[GenerationResponse, None, None]:
@@ -308,7 +308,7 @@ def stream_generate(
Args:
model (nn.Module): The model to use for generation.
tokenizer (PreTrainedTokenizer): The tokenizer.
prompt (Union[str, List[int]]): The input prompt string or integer tokens.
prompt (Union[str, mx.array, List[int]]): The input prompt string or integer tokens.
max_tokens (int): The maximum number of tokens. Default: ``100``.
kwargs: The remaining options get passed to :func:`generate_step`.
See :func:`generate_step` for more details.
@@ -320,7 +320,11 @@ def stream_generate(
if not isinstance(tokenizer, TokenizerWrapper):
tokenizer = TokenizerWrapper(tokenizer)
prompt = mx.array(prompt if isinstance(prompt, list) else tokenizer.encode(prompt))
if not isinstance(prompt, mx.array):
prompt = mx.array(
prompt if isinstance(prompt, list) else tokenizer.encode(prompt)
)
detokenizer = tokenizer.detokenizer
with wired_limit(model, [generation_stream]):