clear cache during prompt processing (#1027)

This commit is contained in:
Awni Hannun 2024-10-09 16:48:32 -07:00 committed by GitHub
parent b7373cb44f
commit 4360e7ccec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -242,6 +242,7 @@ def generate_step(
            model(y[:prefill_step_size][None], cache=prompt_cache)
            mx.eval([c.state for c in prompt_cache])
            y = y[prefill_step_size:]
            mx.metal.clear_cache()

        y, logprobs = _step(y)