diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index 12553b8a..954eb81c 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -155,23 +155,20 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
         current_input = expanded_prompts[idx]
         while len(current_tokens) < max_tokens:
             logits = model(current_input[None])[:, -1]
-            next_token = mx.argmax(logits, axis=-1)
+            probs = nn.softmax(logits, axis=-1)
+            next_token = mx.argmax(probs, axis=-1)
             token = next_token.item()
             current_tokens.append(token)
             tokens_generated += 1
-
             if token == tokenizer.eos_token_id:
                 break
-
             if (len(current_tokens) >= len(end_sequence) and
                 mx.array_equal(
                     mx.array(current_tokens[-len(end_sequence):]),
                     end_sequence
                 )):
                 break
-
             current_input = mx.concatenate([current_input, mx.array([token])])
-
             if len(current_tokens) % 32 == 0:
                 mx.eval(current_input)
                 mx.metal.clear_cache()
@@ -182,7 +179,6 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
             max_tokens=max_tokens,
             sampler=lambda x: mx.argmax(x, axis=-1)
         )
-
         for token, _ in generator:
             current_tokens.append(token)
             tokens_generated += 1
@@ -276,6 +272,8 @@ def grpo_loss(
             print(f"Generation error: {e}")
             continue
 
+    mx.metal.clear_cache()
+
     expanded_answers = []
     expanded_prompts = []
     for i in range(batch_size):
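
For readability, a minimal self-contained sketch of the greedy per-token loop that the first hunk arrives at. The model call, EOS check, and periodic cache clearing are taken from the diff context; greedy_decode and its signature are illustrative assumptions, not part of grpo_trainer.py.

import mlx.core as mx
import mlx.nn as nn

def greedy_decode(model: nn.Module, prompt: mx.array, eos_token_id: int, max_tokens: int = 256) -> list:
    # Hypothetical helper mirroring the loop in generate_grpo (an assumption, not the actual API).
    tokens = []
    current_input = prompt
    while len(tokens) < max_tokens:
        logits = model(current_input[None])[:, -1]   # logits at the last position
        probs = nn.softmax(logits, axis=-1)          # normalize; the greedy argmax is unchanged
        token = mx.argmax(probs, axis=-1).item()
        tokens.append(token)
        if token == eos_token_id:
            break
        current_input = mx.concatenate([current_input, mx.array([token])])
        if len(tokens) % 32 == 0:                    # evaluate and clear the Metal cache periodically to bound memory
            mx.eval(current_input)
            mx.metal.clear_cache()
    return tokens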