mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Clear cache every now and then (#1081)
* Clear the cache every now and then
* Don't need the user arg anymore
This commit is contained in:
parent
8160e0c4e5
commit
e510987870
@@ -90,12 +90,6 @@ def setup_arg_parser():
|
||||
action="store_true",
|
||||
help="Colorize output based on T[0] probability",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache-limit-gb",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Set the MLX cache limit in GB",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-kv-size",
|
||||
type=int,
|
||||
@@ -164,9 +158,6 @@ def main():
|
||||
|
||||
mx.random.seed(args.seed)
|
||||
|
||||
if args.cache_limit_gb is not None:
|
||||
mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
|
||||
|
||||
# Load the prompt cache and metadata if a cache file is provided
|
||||
using_cache = args.prompt_cache_file is not None
|
||||
if using_cache:
|
||||
|
@@ -310,10 +310,14 @@ def generate_step(
|
||||
y, logprobs = _step(y)
|
||||
|
||||
mx.async_eval(y, logprobs)
|
||||
n = 0
|
||||
while True:
|
||||
next_y, next_logprobs = _step(y)
|
||||
mx.async_eval(next_y, next_logprobs)
|
||||
yield y.item(), logprobs
|
||||
if n % 256 == 0:
|
||||
mx.metal.clear_cache()
|
||||
n += 1
|
||||
y, logprobs = next_y, next_logprobs
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user