diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py index ab1ab605..4a1e6bbf 100644 --- a/llms/mlx_lm/tuner/grpo_trainer.py +++ b/llms/mlx_lm/tuner/grpo_trainer.py @@ -159,7 +159,6 @@ def get_per_token_logps(model, inputs, lengths): logits = model(inputs).astype(mx.float16) logits = logits[:, :-1, :] targets = inputs[:, 1:] - mx.eval(logits) per_token_logps = [] for i in range(logits.shape[0]): seq_len = int(lengths[i]) - 1 @@ -172,6 +171,7 @@ def get_per_token_logps(model, inputs, lengths): axis=-1 ).squeeze(-1) per_token_logps.append(token_log_probs) + mx.eval(logits) return per_token_logps