Logprobs info to completion API (#806)

* Initial implementation

* Fix handling of return_step_logits in return

* Fixed OpenAI parameter expectations and logprob structure and datatypes

* pre-commit black formatting

* Remove unused parameter

* fix log probs

* fix colorize

* nits in server

* nits in server

* Fix top_logprobs structure (a dict) and include tokens in logprobs response

* nits

* fix types

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Authored by: Chime Ogbuji
Date: 2024-06-23 13:35:13 -04:00
Committed by: GitHub
Parent: a7598e9456
Commit: 1d701a1831
3 changed files with 94 additions and 43 deletions
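The server-side changes that actually expose these log probabilities over the completion API are not part of the diff shown below. As a rough illustration only, a request could look like the following; the integer `logprobs` field follows OpenAI's completions convention and is an assumption about this commit's server behavior, not something visible in the diff (the default `mlx_lm.server` address is used).

# Illustrative request against a locally running `python -m mlx_lm.server`;
# the `logprobs` field and the response shape are assumptions modeled on the
# OpenAI completions format, not taken from this diff.
import json
import urllib.request

payload = {
    "prompt": "Hello",
    "max_tokens": 5,
    "logprobs": 3,  # ask for the top-3 alternatives per generated token
}
req = urllib.request.Request(
    "http://localhost:8080/v1/completions",
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.load(resp)["choices"][0].get("logprobs"))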


@@ -149,10 +149,11 @@ def generate_step(
           consider for repetition penalty. Default: ``20``.
         top_p (float, optional): Nucleus sampling, higher means model considers
           more less likely words.
         logit_bias (dictionary, optional): Additive logit bias.
 
     Yields:
-        Generator[Tuple[mx.array, mx.array]]: A generator producing
-        one token and probability per call.
+        Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
+        one token and a vector of log probabilities.
     """
 
     def sample(logits: mx.array) -> Tuple[mx.array, float]:
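The updated Yields annotation means each step now hands back the full log-probability vector rather than a single probability. A minimal consumer sketch, assuming the public mlx_lm load() and generate_step() APIs; the model repo below is only an example:

# Minimal sketch of consuming the new (token, logprobs) pairs; the model
# path is illustrative, not part of this commit.
import mlx.core as mx
from mlx_lm import load
from mlx_lm.utils import generate_step

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.2-4bit")
prompt = mx.array(tokenizer.encode("Hello"))

for (token, logprobs), _ in zip(generate_step(prompt, model), range(5)):
    # logprobs is a 1-D vector over the vocabulary (log-softmax of the logits)
    print(token, logprobs[token].item())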
@@ -160,7 +161,7 @@ def generate_step(
             indices = mx.array(list(logit_bias.keys()))
             values = mx.array(list(logit_bias.values()))
             logits[:, indices] += values
 
-        softmax_logits = mx.softmax(logits)
+        logprobs = logits - mx.logsumexp(logits)
 
         if temp == 0:
             token = mx.argmax(logits, axis=-1)
@@ -170,8 +171,7 @@ def generate_step(
             else:
                 token = mx.random.categorical(logits * (1 / temp))
 
-        prob = softmax_logits[0, token]
-        return token, prob
+        return token, logprobs
 
     if repetition_penalty and (
         repetition_penalty < 0 or not isinstance(repetition_penalty, float)
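Subtracting mx.logsumexp from the logits is the log-softmax (valid here because the batch dimension is 1, so the full reduction matches a per-row reduction), which means the per-token probability the old code returned is still recoverable as the exponential of the sampled token's log-prob. A small self-contained check of that equivalence; shapes and values are illustrative:

# Sketch: for a single (1, vocab) logit row, logits - logsumexp(logits)
# is the log-softmax, so exp(logprobs) matches softmax(logits).
import mlx.core as mx

logits = mx.random.normal((1, 8))          # stand-in for one step's logits
logprobs = logits - mx.logsumexp(logits)   # same expression as in sample()
probs = mx.softmax(logits, axis=-1)

print(mx.allclose(mx.exp(logprobs), probs).item())  # True
token = mx.argmax(logits, axis=-1).item()
print(mx.exp(logprobs[0, token]).item())   # probability of the sampled token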
@@ -202,24 +202,24 @@ def generate_step(
             logits = apply_repetition_penalty(
                 logits, repetition_context, repetition_penalty
             )
-            y, prob = sample(logits)
+            y, logprobs = sample(logits)
             repetition_context.append(y.item())
         else:
-            y, prob = sample(logits)
+            y, logprobs = sample(logits)
 
         if repetition_context_size:
             if len(repetition_context) > repetition_context_size:
                 repetition_context = repetition_context[-repetition_context_size:]
-        return y, prob
+        return y, logprobs.squeeze(0)
 
-    y, p = _step(y)
+    y, logprobs = _step(y)
 
     mx.async_eval(y)
     while True:
-        next_y, next_p = _step(y)
+        next_y, next_logprobs = _step(y)
         mx.async_eval(next_y)
-        yield y.item(), p
-        y, p = next_y, next_p
+        yield y.item(), logprobs
+        y, logprobs = next_y, next_logprobs
 
 
 def stream_generate(
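The commit message mentions returning top_logprobs as a dict and including tokens in the logprobs response; those server changes are not shown in this diff. As a hedged illustration only, one step's logprobs vector could be turned into a response entry like this; the helper name and response fields are assumptions modeled on the OpenAI completions format, not code from this commit:

# Hypothetical helper, not part of this commit: build one OpenAI-style
# logprobs entry from the sampled token id and the step's log-prob vector.
import mlx.core as mx

def logprobs_entry(token: int, logprobs: mx.array, tokenizer, top_k: int = 5) -> dict:
    # Indices of the top_k highest log-probabilities for this step
    top_ids = mx.argsort(-logprobs)[:top_k].tolist()
    return {
        "tokens": [tokenizer.decode([token])],
        "token_logprobs": [logprobs[token].item()],
        "top_logprobs": [{tokenizer.decode([i]): logprobs[i].item() for i in top_ids}],
    }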
@@ -249,7 +249,7 @@ def stream_generate(
     detokenizer = tokenizer.detokenizer
 
     detokenizer.reset()
-    for (token, prob), n in zip(
+    for (token, _), n in zip(
         generate_step(prompt_tokens, model, **kwargs),
         range(max_tokens),
     ):
@@ -301,7 +301,7 @@ def generate(
     tic = time.perf_counter()
     detokenizer.reset()
-    for (token, prob), n in zip(
+    for (token, logprobs), n in zip(
         generate_step(prompt_tokens, model, **kwargs),
         range(max_tokens),
     ):
@@ -316,7 +316,7 @@ def generate(
         if formatter:
             # We have to finalize so that the prob corresponds to the last segment
             detokenizer.finalize()
-            formatter(detokenizer.last_segment, prob.item())
+            formatter(detokenizer.last_segment, mx.exp(logprobs[token]).item())
         else:
             print(detokenizer.last_segment, end="", flush=True)
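In generate(), the formatter now receives the probability of the sampled token (the exponential of its log-prob) along with the decoded segment; the "fix colorize" bullet refers to the CLI's --colorize output, which is one such formatter. The function below is only a sketch of that idea, not the exact implementation or color scale used by mlx_lm:

# Illustrative formatter (not from this commit): color a segment by the
# probability of its final token, mirroring how generate() calls
# formatter(detokenizer.last_segment, mx.exp(logprobs[token]).item()).
def colorprint_by_prob(segment: str, prob: float) -> None:
    # ANSI 256-color codes from red (low confidence) to green (high confidence)
    if prob > 0.95:
        color = 46    # green
    elif prob > 0.70:
        color = 226   # yellow
    elif prob > 0.30:
        color = 208   # orange
    else:
        color = 196   # red
    print(f"\033[38;5;{color}m{segment}\033[0m", end="", flush=True)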