mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
Make llm async eval less brittle
This commit is contained in:
@@ -246,10 +246,10 @@ def generate_step(
|
|||||||
|
|
||||||
y, logprobs = _step(y)
|
y, logprobs = _step(y)
|
||||||
|
|
||||||
mx.async_eval(y)
|
mx.async_eval(y, logprobs)
|
||||||
while True:
|
while True:
|
||||||
next_y, next_logprobs = _step(y)
|
next_y, next_logprobs = _step(y)
|
||||||
mx.async_eval(next_y)
|
mx.async_eval(next_y, next_logprobs)
|
||||||
yield y.item(), logprobs
|
yield y.item(), logprobs
|
||||||
y, logprobs = next_y, next_logprobs
|
y, logprobs = next_y, next_logprobs
|
||||||
|
|
||||||
@@ -348,7 +348,9 @@ def generate(
|
|||||||
if formatter:
|
if formatter:
|
||||||
# We have to finalize so that the prob corresponds to the last segment
|
# We have to finalize so that the prob corresponds to the last segment
|
||||||
detokenizer.finalize()
|
detokenizer.finalize()
|
||||||
formatter(detokenizer.last_segment, mx.exp(logprobs[token]).item())
|
with mx.stream(mx.default_stream(mx.cpu)):
|
||||||
|
prob = mx.exp(logprobs[token]).item()
|
||||||
|
formatter(detokenizer.last_segment, prob)
|
||||||
else:
|
else:
|
||||||
print(detokenizer.last_segment, end="", flush=True)
|
print(detokenizer.last_segment, end="", flush=True)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user