Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-09-04 15:54:34 +08:00)
Logprobs info to completion API (#806)
* Initial implementation
* Fix handling of return_step_logits in return
* Fixed OpenAI parameter expectations and logprob structure and datatypes
* pre-commit black formatting
* Remove unused parameter
* fix log probs
* fix colorize
* nits in server
* nits in server
* Fix top_logprobs structure (a dict) and include tokens in logprobs response
* nits
* fix types

Co-authored-by: Awni Hannun <awni@apple.com>
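For context, the commit message's mention of the top_logprobs structure refers to the OpenAI text-completion convention, in which the logprobs object carries parallel lists and each top_logprobs entry is a dict mapping a token string to its log probability. The snippet below only illustrates that shape (the server change itself is not part of the diff shown on this page); the token strings and values are invented.

# Illustrative OpenAI-style logprobs payload for one generated token
# (tokens and values are made up; this is not output from mlx_lm).
logprobs_entry = {
    "tokens": ["Hello"],                # the sampled token, as text
    "token_logprobs": [-0.12],          # log probability of the sampled token
    "top_logprobs": [                   # one dict per position: token -> logprob
        {"Hello": -0.12, "Hi": -2.31, "Hey": -3.05}
    ],
}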
@@ -149,10 +149,11 @@ def generate_step(
           consider for repetition penalty. Default: ``20``.
         top_p (float, optional): Nulceus sampling, higher means model considers
           more less likely words.
         logit_bias (dictionary, optional): Additive logit bias.
 
     Yields:
-        Generator[Tuple[mx.array, mx.array]]: A generator producing
-        one token and probability per call.
+        Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
+        one token and a vector of log probabilities.
     """
 
     def sample(logits: mx.array) -> Tuple[mx.array, float]:
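As a sanity check of the new yield shape, here is a minimal sketch of a caller iterating over generate_step. It assumes mlx_lm's load() helper and a Hugging Face style tokenizer; the model repo name is only an example.

# Minimal sketch: consume (token, logprobs) pairs from generate_step.
import mlx.core as mx
from mlx_lm.utils import generate_step, load

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.2-4bit")  # example repo
prompt = mx.array(tokenizer.encode("Hello"))

for (token, logprobs), _ in zip(generate_step(prompt, model), range(8)):
    # `token` is a Python int, `logprobs` is a vector over the vocabulary
    print(token, logprobs[token].item())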
@@ -160,7 +161,7 @@ def generate_step(
             indices = mx.array(list(logit_bias.keys()))
             values = mx.array(list(logit_bias.values()))
             logits[:, indices] += values
-        softmax_logits = mx.softmax(logits)
+        logprobs = logits - mx.logsumexp(logits)
 
         if temp == 0:
             token = mx.argmax(logits, axis=-1)
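The added line computes the same quantity as taking the log of the softmax, but directly in log space via logsumexp, so the probabilities are never materialized before the log. A small check of that equivalence (a sketch, not part of the diff):

# Sketch: log-softmax via logsumexp matches log(softmax), but stays in log space.
import mlx.core as mx

logits = mx.array([[2.0, 0.5, -1.0]])
logprobs = logits - mx.logsumexp(logits)        # what the new code computes
reference = mx.log(mx.softmax(logits))          # equivalent, via the old softmax path
print(mx.allclose(logprobs, reference).item())  # True
print(mx.exp(logprobs).sum().item())            # probabilities sum to 1.0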
@@ -170,8 +171,7 @@ def generate_step(
             else:
                 token = mx.random.categorical(logits * (1 / temp))
 
-        prob = softmax_logits[0, token]
-        return token, prob
+        return token, logprobs
 
     if repetition_penalty and (
         repetition_penalty < 0 or not isinstance(repetition_penalty, float)
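The per-token probability the old code returned can still be recovered from the logprobs vector by exponentiating the entry for the sampled token, which is exactly what the updated generate() does further down. A one-line sketch:

# Sketch: recover the old `prob` from the new logprobs vector.
# Inside sample(), logits has shape (1, vocab_size), so index row 0.
prob = mx.exp(logprobs[0, token])   # equals the removed softmax_logits[0, token]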
@@ -202,24 +202,24 @@ def generate_step(
             logits = apply_repetition_penalty(
                 logits, repetition_context, repetition_penalty
             )
-            y, prob = sample(logits)
+            y, logprobs = sample(logits)
             repetition_context.append(y.item())
         else:
-            y, prob = sample(logits)
+            y, logprobs = sample(logits)
 
         if repetition_context_size:
             if len(repetition_context) > repetition_context_size:
                 repetition_context = repetition_context[-repetition_context_size:]
-        return y, prob
+        return y, logprobs.squeeze(0)
 
-    y, p = _step(y)
+    y, logprobs = _step(y)
 
     mx.async_eval(y)
     while True:
-        next_y, next_p = _step(y)
+        next_y, next_logprobs = _step(y)
         mx.async_eval(next_y)
-        yield y.item(), p
-        y, p = next_y, next_p
+        yield y.item(), logprobs
+        y, logprobs = next_y, next_logprobs
 
 
 def stream_generate(
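The loop above relies on mx.async_eval to dispatch the next token's computation before the current one is consumed, so Python-side work (detokenization, printing, HTTP streaming) overlaps with the next forward pass. A standalone sketch of that behaviour, unrelated to the model code:

# Sketch: async_eval launches evaluation without blocking; .item() later waits.
import mlx.core as mx

a = mx.random.normal((1024, 1024))
b = a @ a                  # recorded lazily, not yet computed
mx.async_eval(b)           # start evaluating b in the background
# ... other Python work can run here while b is computed ...
print(b.sum().item())      # blocks only until the result is ready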
@@ -249,7 +249,7 @@ def stream_generate(
     detokenizer = tokenizer.detokenizer
 
     detokenizer.reset()
-    for (token, prob), n in zip(
+    for (token, _), n in zip(
         generate_step(prompt_tokens, model, **kwargs),
         range(max_tokens),
     ):
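stream_generate keeps its text-only interface and simply discards the logprobs it now receives (the underscore above). A hedged usage sketch, with an example model repo:

# Sketch: streaming text segments; logprobs from generate_step are dropped here.
from mlx_lm.utils import load, stream_generate

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.2-4bit")  # example repo
for segment in stream_generate(model, tokenizer, "Write a haiku", max_tokens=64):
    print(segment, end="", flush=True)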
@@ -301,7 +301,7 @@ def generate(
     tic = time.perf_counter()
     detokenizer.reset()
 
-    for (token, prob), n in zip(
+    for (token, logprobs), n in zip(
         generate_step(prompt_tokens, model, **kwargs),
         range(max_tokens),
     ):
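With the full logprobs vector available at every step, a caller (for example the HTTP server this PR extends) can derive the top-k alternatives needed for an OpenAI-style top_logprobs response. The helper below is only a sketch of that idea, not code from this commit; it assumes a 1-D logprobs vector as yielded by generate_step and a tokenizer with a decode() method.

# Sketch: build a token -> logprob dict for the k most likely tokens.
import mlx.core as mx

def top_logprobs_dict(logprobs: mx.array, tokenizer, k: int = 5) -> dict:
    # indices of the k largest log probabilities (order within the k not guaranteed)
    top_ids = mx.argpartition(-logprobs, kth=k - 1)[:k].tolist()
    return {tokenizer.decode([tid]): logprobs[tid].item() for tid in top_ids}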
@@ -316,7 +316,7 @@ def generate(
     if formatter:
         # We have to finalize so that the prob corresponds to the last segment
         detokenizer.finalize()
-        formatter(detokenizer.last_segment, prob.item())
+        formatter(detokenizer.last_segment, mx.exp(logprobs[token]).item())
     else:
         print(detokenizer.last_segment, end="", flush=True)
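The formatter contract visible here is a callable taking (text_segment, probability), with the probability now derived from the logprobs vector; the commit message's "fix colorize" concerns a formatter of this kind. The one below is only an illustrative stand-in, not the formatter shipped with mlx_lm.

# Sketch: a formatter that colors low-confidence segments (ANSI escape codes).
def simple_formatter(segment: str, prob: float) -> None:
    color = "\033[91m" if prob < 0.3 else "\033[0m"   # red below 30% probability
    print(f"{color}{segment}\033[0m", end="", flush=True)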