From e88f0fad4b82dbcd6f34501267477eb382807e0c Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Sun, 9 Mar 2025 00:18:33 +0100
Subject: [PATCH] clean up

---
 llms/mlx_lm/tuner/grpo_trainer.py | 15 ++++-----------
 1 file changed, 4 insertions(+), 11 deletions(-)

diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index d65cf4b6..31b6dc58 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -121,9 +121,9 @@ def generate_grpo(
     prompt_tokens,
     max_tokens: int,
     group_size: int,
-    end_token: str = "",
-    temperature: float = 0.8,
-    batch_size: int = 1,
+    temperature: float,
+    batch_size: int,
+    end_token: str = ""
 ):
     try:
         end_sequence = mx.array(tokenizer.encode(end_token))
@@ -239,7 +239,6 @@ def grpo_loss(
     expanded_answers = []
     expanded_prompts = []
-
     unique_prompt_indices = sorted(set(batch_indices))
     grouped_completions = {idx: [] for idx in unique_prompt_indices}
@@ -262,7 +261,6 @@ def grpo_loss(
     all_completions = ordered_completions
     all_completion_texts = ordered_completion_texts
     batch_indices = ordered_batch_indices
-
     max_length = max(ids.shape[0] for ids in all_completions)
     padded_completions = []
     attention_masks = []
@@ -617,11 +615,8 @@ def train_grpo(
     state = [model.state, optimizer.state]

     def step(batch):
-        # Extract prompt tokens from the batch
         prompt_tokens, targets, prompt_lens, target_lens = batch

-        # First, generate completions without gradient tracking
-        # The model will be frozen during this call
         all_completions, all_completion_texts, batch_indices = generate_grpo(
             model=model,
             tokenizer=tokenizer,
             prompt_tokens=prompt_tokens,
             max_tokens=args.max_completion_length,
             group_size=args.group_size,
             temperature=args.temperature
         )
-
-        # Now calculate loss and gradients with pre-generated completions
-        # We need to update loss_fn to accept these pre-generated completions
+
         (loss, toks, metrics), grad = loss_value_and_grad(
             model,
             tokenizer=tokenizer,
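For reference, the functional change in the first hunk is a reorder of generate_grpo's parameters: temperature and batch_size lose their defaults (0.8 and 1) and become required, while end_token keeps its default and moves to the end. Below is a minimal runnable sketch of the new signature and a conforming call; the function body and all argument values are illustrative placeholders, not taken from grpo_trainer.py. Note that the step() call in the last hunk does not pass batch_size, which would raise a TypeError under the new signature unless it is supplied in the truncated remainder of the patch.

# Sketch of the reordered signature this patch introduces. The body is
# elided; only the parameter order and defaults are taken from the diff.
def generate_grpo(
    model,
    tokenizer,
    prompt_tokens,
    max_tokens: int,
    group_size: int,
    temperature: float,   # no default anymore (was 0.8)
    batch_size: int,      # no default anymore (was 1)
    end_token: str = "",  # default kept, parameter moved to the end
):
    return [], [], []  # placeholder for (completions, texts, batch_indices)

# A conforming call now has to spell out temperature and batch_size.
# All values here are illustrative.
completions, texts, indices = generate_grpo(
    model=None,      # stands in for a loaded mlx model
    tokenizer=None,  # stands in for its tokenizer
    prompt_tokens=[[1, 2, 3]],
    max_tokens=128,
    group_size=4,
    temperature=0.8,
    batch_size=1,
)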