Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 17:31:18 +08:00)
Make llm async eval less brittle (#1040)
* Make llm async eval less brittle
* nit
This commit is contained in:
Parent: 1e0cda68c6
Commit: c799133998
@@ -246,10 +246,10 @@ def generate_step(
|
|||||||
|
|
||||||
y, logprobs = _step(y)
|
y, logprobs = _step(y)
|
||||||
|
|
||||||
mx.async_eval(y)
|
mx.async_eval(y, logprobs)
|
||||||
while True:
|
while True:
|
||||||
next_y, next_logprobs = _step(y)
|
next_y, next_logprobs = _step(y)
|
||||||
mx.async_eval(next_y)
|
mx.async_eval(next_y, next_logprobs)
|
||||||
yield y.item(), logprobs
|
yield y.item(), logprobs
|
||||||
y, logprobs = next_y, next_logprobs
|
y, logprobs = next_y, next_logprobs
|
||||||
|
|
||||||
@@ -348,7 +348,9 @@ def generate(
|
|||||||
if formatter:
|
if formatter:
|
||||||
# We have to finalize so that the prob corresponds to the last segment
|
# We have to finalize so that the prob corresponds to the last segment
|
||||||
detokenizer.finalize()
|
detokenizer.finalize()
|
||||||
formatter(detokenizer.last_segment, mx.exp(logprobs[token]).item())
|
with mx.stream(mx.cpu):
|
||||||
|
prob = mx.exp(logprobs[token]).item()
|
||||||
|
formatter(detokenizer.last_segment, prob)
|
||||||
else:
|
else:
|
||||||
print(detokenizer.last_segment, end="", flush=True)
|
print(detokenizer.last_segment, end="", flush=True)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user