put prompt processing in same stream

This commit is contained in:
Awni Hannun 2024-11-25 08:06:12 -08:00
parent adaab81029
commit 3586c876aa


@@ -274,13 +274,14 @@ def generate_step(
         y = sampler(logprobs)
         return y, logprobs.squeeze(0)
 
-    while y.size > prefill_step_size:
-        model(y[:prefill_step_size][None], cache=prompt_cache)
-        mx.eval([c.state for c in prompt_cache])
-        y = y[prefill_step_size:]
-        mx.metal.clear_cache()
+    with mx.stream(generation_stream):
+        while y.size > prefill_step_size:
+            model(y[:prefill_step_size][None], cache=prompt_cache)
+            mx.eval([c.state for c in prompt_cache])
+            y = y[prefill_step_size:]
+            mx.metal.clear_cache()
 
-    y, logprobs = _step(y)
+        y, logprobs = _step(y)
 
     mx.async_eval(y, logprobs)
     n = 0
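
For context, a minimal standalone sketch of the pattern this commit applies: create one MLX stream and schedule both the chunked prompt prefill and the first decode step on it via the `mx.stream` context manager, so both kinds of work land on the same device queue. The stream name, the toy tensors, and the placeholder for the model call are assumptions for illustration, not part of the diff.

import mlx.core as mx

# A single stream shared by prompt processing and generation.
# (In mlx_lm this is a module-level `generation_stream`; its creation here
# is an assumption for illustration.)
generation_stream = mx.new_stream(mx.default_device())

# Stand-ins for the prompt tokens and the per-step compute.
prompt = mx.zeros((1024,), dtype=mx.int32)
weights = mx.random.normal((16, 16))

prefill_step_size = 256
y = prompt
with mx.stream(generation_stream):
    # Chunked "prefill": process the prompt in fixed-size pieces on the
    # shared stream, evaluating after each chunk to bound memory use.
    while y.size > prefill_step_size:
        chunk = y[:prefill_step_size]
        mx.eval(chunk)  # placeholder for model(chunk[None], cache=prompt_cache)
        y = y[prefill_step_size:]
        mx.metal.clear_cache()

    # The first decode step now runs on the same stream as the prefill above.
    logits = weights @ mx.random.normal((16,))

mx.async_eval(logits)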