return logprobs

farris
2024-10-05 20:59:18 -07:00
parent 9bb2dd62f3
commit c7508270c3


@@ -314,10 +314,11 @@ def generate(
     tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
     prompt: str,
     max_tokens: int = 100,
+    logprobs_seq: bool = False,
     verbose: bool = False,
     formatter: Optional[Callable] = None,
     **kwargs,
-) -> Union[str, Generator[str, None, None]]:
+) -> Union[str, Generator[str, None, None], list[mx.array]]:
     """
     Generate a complete response from the model.
@@ -326,6 +327,8 @@ def generate(
         tokenizer (PreTrainedTokenizer): The tokenizer.
         prompt (str): The string prompt.
         max_tokens (int): The maximum number of tokens. Default: ``100``.
+        logprobs_seq (bool): If ``True``, return the logprobs of the sampled
+            tokens across the generated sequence. Default: ``False``.
         verbose (bool): If ``True``, print tokens and timing information.
             Default: ``False``.
         formatter (Optional[Callable]): A function which takes a token and a
@@ -346,6 +349,7 @@ def generate(
     tic = time.perf_counter()
     detokenizer.reset()
+    all_logprobs = []
     for (token, logprobs), n in zip(
         generate_step(prompt_tokens, model, **kwargs),
         range(max_tokens),
@@ -355,7 +359,10 @@ def generate(
             tic = time.perf_counter()
         if token == tokenizer.eos_token_id:
             break
         detokenizer.add_token(token)
+        if logprobs_seq:
+            all_logprobs.append(logprobs)
         if verbose:
             if formatter:
@@ -364,7 +371,7 @@ def generate(
                 formatter(detokenizer.last_segment, mx.exp(logprobs[token]).item())
             else:
                 print(detokenizer.last_segment, end="", flush=True)

     token_count = n + 1
     detokenizer.finalize()
@@ -382,8 +389,11 @@ def generate(
         peak_mem = mx.metal.get_peak_memory() / 2**30
         print(f"Peak memory: {peak_mem:.3f} GB")
-    return detokenizer.text
+    if logprobs_seq:
+        return detokenizer.text, all_logprobs
+    else:
+        return detokenizer.text


 def load_config(model_path: Path) -> dict:
     try:
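
Taken together, these hunks change the return shape of generate: with logprobs_seq=True it returns a (text, all_logprobs) tuple instead of just the text. Below is a minimal usage sketch, assuming the mlx_lm package of this era re-exports load and the patched generate at the top level; the model repo name is only an illustrative example.

import mlx.core as mx
from mlx_lm import load, generate  # assumes top-level re-exports of this era

# Hypothetical checkpoint; any mlx-community model works the same way.
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.2-4bit")

# With logprobs_seq=True, generate returns (text, all_logprobs).
text, all_logprobs = generate(
    model,
    tokenizer,
    prompt="The capital of France is",
    max_tokens=20,
    logprobs_seq=True,
)

# Each entry is the mx.array yielded by generate_step for one step; per the
# formatter call in the diff (logprobs[token]), it is a log-probability
# distribution over the full vocabulary, so exponentiating recovers
# probabilities.
step0 = all_logprobs[0]
print(text)
print("steps recorded:", len(all_logprobs))
print("most likely token prob at step 0:", mx.exp(step0.max()).item())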