Commit: return logprobs
Mirror of https://github.com/ml-explore/mlx-examples.git
@@ -314,10 +314,11 @@ def generate(
     tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
     prompt: str,
     max_tokens: int = 100,
+    logprobs_seq: bool = False,
     verbose: bool = False,
     formatter: Optional[Callable] = None,
     **kwargs,
-) -> Union[str, Generator[str, None, None]]:
+) -> Union[str, Generator[str, None, None], list[mx.array]]:
     """
     Generate a complete response from the model.

@@ -326,6 +327,8 @@ def generate(
        tokenizer (PreTrainedTokenizer): The tokenizer.
        prompt (str): The string prompt.
        max_tokens (int): The maximum number of tokens. Default: ``100``.
+       logprobs_seq (bool): Return logprobs for the sampled tokens across the
+           generated sequence. Default: ``False``.
        verbose (bool): If ``True``, print tokens and timing information.
           Default: ``False``.
        formatter (Optional[Callable]): A function which takes a token and a
@@ -346,6 +349,7 @@ def generate(
     tic = time.perf_counter()
     detokenizer.reset()

+    all_logprobs = []
     for (token, logprobs), n in zip(
         generate_step(prompt_tokens, model, **kwargs),
         range(max_tokens),
@@ -355,7 +359,10 @@ def generate(
             tic = time.perf_counter()
         if token == tokenizer.eos_token_id:
             break

         detokenizer.add_token(token)
+        if logprobs_seq:
+            all_logprobs.append(logprobs)
+
         if verbose:
             if formatter:
@@ -364,7 +371,7 @@ def generate(
                 formatter(detokenizer.last_segment, mx.exp(logprobs[token]).item())
             else:
                 print(detokenizer.last_segment, end="", flush=True)
-
+
     token_count = n + 1
     detokenizer.finalize()

@@ -382,8 +389,11 @@ def generate(
         peak_mem = mx.metal.get_peak_memory() / 2**30
         print(f"Peak memory: {peak_mem:.3f} GB")

-    return detokenizer.text
+    if logprobs_seq:
+        return detokenizer.text, all_logprobs
+    else:
+        return detokenizer.text


 def load_config(model_path: Path) -> dict:
     try:
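
For readers trying out the change, a minimal usage sketch of the new flag follows. It assumes the mlx-lm package built from this repo is installed; the model name is only illustrative, and the (text, all_logprobs) tuple return matches the diff above.

# Minimal usage sketch of the new flag (assumes the mlx-lm package from this
# repo is installed; the model name below is illustrative).
from mlx_lm import load, generate

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.2-4bit")

# With logprobs_seq=True, generate returns (text, all_logprobs): the decoded
# string plus one mx.array of vocabulary log probabilities per generated token.
text, all_logprobs = generate(
    model,
    tokenizer,
    prompt="Why is the sky blue?",
    max_tokens=50,
    logprobs_seq=True,
)
print(text)
print(len(all_logprobs))  # number of generated tokens, at most max_tokens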
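
Since each entry of all_logprobs is a full-vocabulary log-probability array for one decoding step, per-step confidences can be recovered after generation. A hypothetical post-processing sketch, continuing from the example above and not part of this commit:

import mlx.core as mx

# Hypothetical post-processing, not part of this commit: the probability of
# the most likely token at each step. Under greedy decoding (temp=0) this is
# the probability of the token that was actually emitted.
step_probs = [mx.exp(lp.max()).item() for lp in all_logprobs]
for i, p in enumerate(step_probs):
    print(f"step {i}: top-token probability = {p:.3f}")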