fix long prompt generations

This commit is contained in:
Awni Hannun
2024-10-09 09:07:38 -07:00
parent fca087be49
commit 3ddd7e9923

View File

@@ -239,8 +239,8 @@ def generate_step(
         return y, logprobs.squeeze(0)
     while y.size > prefill_step_size:
-        model(y[:prefill_step_size][None], cache=cache)
-        mx.eval([c.state for c in cache])
+        model(y[:prefill_step_size][None], cache=prompt_cache)
+        mx.eval([c.state for c in prompt_cache])
         y = y[prefill_step_size:]
     y, logprobs = _step(y)