Make llm async eval less brittle (#1040)

* Make llm async eval less brittle

* nit
This commit is contained in:
Awni Hannun 2024-10-14 10:25:24 -07:00 committed by GitHub
parent 1e0cda68c6
commit c799133998
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -246,10 +246,10 @@ def generate_step(
y, logprobs = _step(y)
mx.async_eval(y)
mx.async_eval(y, logprobs)
while True:
next_y, next_logprobs = _step(y)
mx.async_eval(next_y)
mx.async_eval(next_y, next_logprobs)
yield y.item(), logprobs
y, logprobs = next_y, next_logprobs
@ -348,7 +348,9 @@ def generate(
if formatter:
# We have to finalize so that the prob corresponds to the last segment
detokenizer.finalize()
formatter(detokenizer.last_segment, mx.exp(logprobs[token]).item())
with mx.stream(mx.cpu):
prob = mx.exp(logprobs[token]).item()
formatter(detokenizer.last_segment, prob)
else:
print(detokenizer.last_segment, end="", flush=True)