diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py index 28546cae..098580b1 100644 --- a/llms/mlx_lm/tuner/grpo_trainer.py +++ b/llms/mlx_lm/tuner/grpo_trainer.py @@ -241,6 +241,7 @@ def generate_grpo( if i >= max_tokens - 1: break + mx.metal.clear_cache() mx.eval(current_tokens) if current_tokens: