Put prompt processing in same stream (#1122)

* put prompt processing in same stream

* patch
Awni Hannun 2024-11-25 09:47:00 -08:00 committed by GitHub
parent a5e173802e
commit cfc29c29f4
2 changed files with 8 additions and 7 deletions


@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.
 
-__version__ = "0.20.0"
+__version__ = "0.20.1"


@@ -274,13 +274,14 @@ def generate_step(
             y = sampler(logprobs)
             return y, logprobs.squeeze(0)
 
-    while y.size > prefill_step_size:
-        model(y[:prefill_step_size][None], cache=prompt_cache)
-        mx.eval([c.state for c in prompt_cache])
-        y = y[prefill_step_size:]
-        mx.metal.clear_cache()
+    with mx.stream(generation_stream):
+        while y.size > prefill_step_size:
+            model(y[:prefill_step_size][None], cache=prompt_cache)
+            mx.eval([c.state for c in prompt_cache])
+            y = y[prefill_step_size:]
+            mx.metal.clear_cache()
 
-    y, logprobs = _step(y)
+        y, logprobs = _step(y)
 
     mx.async_eval(y, logprobs)
     n = 0
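
For reference, a minimal runnable sketch of the pattern this change adopts: all prompt prefill work is queued on the same dedicated stream used for decoding, so the first sampled token does not have to synchronize against work queued on a different stream. The `prefill` helper below is hypothetical; `generation_stream` mirrors the module-level stream that mlx_lm creates with `mx.new_stream`.

import mlx.core as mx

# Dedicated stream for all generation work (prefill and decode), assumed
# to be created once at module level, as mlx_lm does.
generation_stream = mx.new_stream(mx.default_device())

def prefill(model, y, prompt_cache, prefill_step_size=512):
    # Hypothetical helper: process the prompt in fixed-size chunks so peak
    # memory stays bounded, keeping every chunk on the generation stream.
    with mx.stream(generation_stream):
        while y.size > prefill_step_size:
            model(y[:prefill_step_size][None], cache=prompt_cache)
            # Materialize this chunk's KV-cache state before the next one.
            mx.eval([c.state for c in prompt_cache])
            y = y[prefill_step_size:]
            # Release temporary buffers between chunks.
            mx.metal.clear_cache()
    return y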