mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Accept mx.array type for prompt argument for stream_generate (#1125)
* Accept mx.array type for prompt argument for stream_generate
* Fix formatting
This commit is contained in:
parent
cfc29c29f4
commit
cefe793ae0
@ -298,7 +298,7 @@ def generate_step(
|
||||
def stream_generate(
|
||||
model: nn.Module,
|
||||
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
|
||||
prompt: Union[str, List[int]],
|
||||
prompt: Union[str, mx.array, List[int]],
|
||||
max_tokens: int = 100,
|
||||
**kwargs,
|
||||
) -> Generator[GenerationResponse, None, None]:
|
||||
@ -308,7 +308,7 @@ def stream_generate(
|
||||
Args:
|
||||
model (nn.Module): The model to use for generation.
|
||||
tokenizer (PreTrainedTokenizer): The tokenizer.
|
||||
prompt (Union[str, List[int]]): The input prompt string or integer tokens.
|
||||
prompt (Union[str, mx.array, List[int]]): The input prompt string or integer tokens.
|
||||
max_tokens (int): The maximum number of tokens. Default: ``100``.
|
||||
kwargs: The remaining options get passed to :func:`generate_step`.
|
||||
See :func:`generate_step` for more details.
|
||||
@ -320,7 +320,11 @@ def stream_generate(
|
||||
if not isinstance(tokenizer, TokenizerWrapper):
|
||||
tokenizer = TokenizerWrapper(tokenizer)
|
||||
|
||||
prompt = mx.array(prompt if isinstance(prompt, list) else tokenizer.encode(prompt))
|
||||
if not isinstance(prompt, mx.array):
|
||||
prompt = mx.array(
|
||||
prompt if isinstance(prompt, list) else tokenizer.encode(prompt)
|
||||
)
|
||||
|
||||
detokenizer = tokenizer.detokenizer
|
||||
|
||||
with wired_limit(model, [generation_stream]):
|
||||
|
Loading…
Reference in New Issue
Block a user