Clear cache every now and then (#1081)

* clear the MLX Metal cache periodically (every 256 generation steps)

* remove the `--cache-limit-gb` user argument (and its `set_cache_limit` call); it is no longer needed
This commit is contained in:
Awni Hannun 2024-11-01 14:15:32 -07:00 committed by GitHub
parent 8160e0c4e5
commit e510987870
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 4 additions and 9 deletions

View File

@ -90,12 +90,6 @@ def setup_arg_parser():
action="store_true", action="store_true",
help="Colorize output based on T[0] probability", help="Colorize output based on T[0] probability",
) )
parser.add_argument(
"--cache-limit-gb",
type=int,
default=None,
help="Set the MLX cache limit in GB",
)
parser.add_argument( parser.add_argument(
"--max-kv-size", "--max-kv-size",
type=int, type=int,
@ -164,9 +158,6 @@ def main():
mx.random.seed(args.seed) mx.random.seed(args.seed)
if args.cache_limit_gb is not None:
mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
# Load the prompt cache and metadata if a cache file is provided # Load the prompt cache and metadata if a cache file is provided
using_cache = args.prompt_cache_file is not None using_cache = args.prompt_cache_file is not None
if using_cache: if using_cache:

View File

@ -310,10 +310,14 @@ def generate_step(
y, logprobs = _step(y) y, logprobs = _step(y)
mx.async_eval(y, logprobs) mx.async_eval(y, logprobs)
n = 0
while True: while True:
next_y, next_logprobs = _step(y) next_y, next_logprobs = _step(y)
mx.async_eval(next_y, next_logprobs) mx.async_eval(next_y, next_logprobs)
yield y.item(), logprobs yield y.item(), logprobs
if n % 256 == 0:
mx.metal.clear_cache()
n += 1
y, logprobs = next_y, next_logprobs y, logprobs = next_y, next_logprobs