clear cache every now and then

This commit is contained in:
Awni Hannun 2024-11-01 13:54:05 -07:00
parent 8160e0c4e5
commit c102d528ae

View File

@ -310,10 +310,14 @@ def generate_step(
y, logprobs = _step(y)
mx.async_eval(y, logprobs)
n = 0
while True:
next_y, next_logprobs = _step(y)
mx.async_eval(next_y, next_logprobs)
yield y.item(), logprobs
if n % 256 == 0:
mx.metal.clear_cache()
n += 1
y, logprobs = next_y, next_logprobs