mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Make llm async eval less brittle (#1040)
* Make llm async eval less brittle
* nit
This commit is contained in:
parent
1e0cda68c6
commit
c799133998
@@ -246,10 +246,10 @@ def generate_step(
|
||||
|
||||
y, logprobs = _step(y)
|
||||
|
||||
- mx.async_eval(y)
|
||||
+ mx.async_eval(y, logprobs)
|
||||
while True:
|
||||
next_y, next_logprobs = _step(y)
|
||||
- mx.async_eval(next_y)
|
||||
+ mx.async_eval(next_y, next_logprobs)
|
||||
yield y.item(), logprobs
|
||||
y, logprobs = next_y, next_logprobs
|
||||
|
||||
@@ -348,7 +348,9 @@ def generate(
|
||||
if formatter:
|
||||
# We have to finalize so that the prob corresponds to the last segment
|
||||
detokenizer.finalize()
|
||||
- formatter(detokenizer.last_segment, mx.exp(logprobs[token]).item())
|
||||
+ with mx.stream(mx.cpu):
|
||||
+ prob = mx.exp(logprobs[token]).item()
|
||||
+ formatter(detokenizer.last_segment, prob)
|
||||
else:
|
||||
print(detokenizer.last_segment, end="", flush=True)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user