diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index 351ba9de..70ac2eda 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -76,7 +76,6 @@ def generate_grpo(
     # Process in batches
     for batch_start in range(0, total_samples, batch_size):
         batch_end = min(batch_start + batch_size, total_samples)
-        batch_results = []
 
         if is_training:
             # Training mode with batched processing
@@ -92,26 +91,46 @@ def generate_grpo(
             # Track tokens for each sequence in the batch
             batch_tokens = [[] for _ in range(batch_end - batch_start)]
-            active_indices = list(range(batch_end - batch_start))
+
+            # Initial token generation for all sequences in batch
+            for i in range(len(batch_logits)):
+                logits_temp = batch_logits[i] / temperature
+                next_token = mx.random.categorical(logits_temp)
+                token = next_token.item()
+                mx.eval(logits_temp, next_token, token)
+                batch_tokens[i].append(token)
+
+                # Check if this token already completes the sequence
+                if token == tokenizer.eos_token_id:
+                    continue
+                else:
+                    # Set up for next token
+                    current_input = mx.array([token])
+                    batch_logits[i] = model(current_input[None], cache=prompt_caches[i])[:, -1]
+
+            mx.eval(batch_logits)
+            active_indices = [i for i, tokens in enumerate(batch_tokens) if tokens[-1] != tokenizer.eos_token_id and len(tokens) < max_tokens]
 
             # Generate tokens until all sequences are complete
             while active_indices and max(len(tokens) for tokens in batch_tokens) < max_tokens:
                 next_active = []
 
                 for idx in active_indices:
                     logits_temp = batch_logits[idx] / temperature
-                    probs = nn.softmax(logits_temp, axis=-1)
                     next_token = mx.random.categorical(logits_temp)
                     token = next_token.item()
-
-                    test_sequence = batch_tokens[idx] + [token]
-                    is_end = (len(test_sequence) >= len(end_sequence) and
-                              mx.array_equal(
-                                  mx.array(test_sequence[-len(end_sequence):]),
-                                  end_sequence
-                              ))
-
+                    mx.eval(logits_temp, next_token, token)
                     batch_tokens[idx].append(token)
 
+                    # Check for end sequence
+                    if len(batch_tokens[idx]) >= len(end_sequence):
+                        test_sequence = batch_tokens[idx][-len(end_sequence):]
+                        is_end = mx.array_equal(
+                            mx.array(test_sequence),
+                            end_sequence
+                        )
+                    else:
+                        is_end = False
+
                     if is_end or token == tokenizer.eos_token_id or len(batch_tokens[idx]) >= max_tokens:
                         # This sequence is done
                         pass
@@ -123,12 +142,31 @@ def generate_grpo(
                     mx.eval([batch_logits[idx] for idx in next_active])
 
                 active_indices = next_active
+
+            # Clear caches after processing this batch
+            for pc in prompt_caches:
+                del pc
 
             # Add batch results to overall results
            for tokens in batch_tokens:
                 if tokens:
-                    results.append(mx.array(tokens))
-
+                    # Filter out any special tokens that might appear after the end token
+                    if len(tokens) >= len(end_sequence):
+                        for i in range(len(tokens) - len(end_sequence) + 1):
+                            if mx.array_equal(
+                                mx.array(tokens[i:i+len(end_sequence)]),
+                                end_sequence
+                            ):
+                                tokens = tokens[:i+len(end_sequence)]
+                                break
+
+                    # Filter out EOS token if it's the last token
+                    if tokens and tokens[-1] == tokenizer.eos_token_id:
+                        tokens = tokens[:-1]
+
+                    # Only add non-empty token lists
+                    if tokens:
+                        results.append(mx.array(tokens))
         else:
             # Non-training mode with batched processing
             for idx in range(batch_start, batch_end):
@@ -158,7 +196,6 @@ def generate_grpo(
                 results.append(mx.array(current_tokens))
 
     mx.metal.clear_cache()
-    mx.eval(results)
 
     return results
@@ -267,14 +304,7 @@ def grpo_loss(
     # If we didn't generate any completions, return early
     if not all_completions:
-        print("No completions were generated. Returning zero loss.")
-        dummy_loss = mx.zeros(1)
-        dummy_metrics = {
-            'total_rewards_mean': mx.zeros(1),
-            'total_rewards_std': mx.zeros(1),
-            'kl': mx.zeros(1)
-        }
-        return dummy_loss, mx.array(0), dummy_metrics
+        raise ValueError("No completions were generated. Please check your model and inputs.")
 
     # Create expanded prompts and answers based on actual generated completions
     expanded_answers = []
@@ -453,11 +483,24 @@ def grpo_loss(
     if is_validation and all_completion_texts:
         print("\n=== Validation Sample Details ===")
+
+        # Print the input context (prompt)
+        last_prompt_idx = batch_indices[-1] if batch_indices else 0
+
+        if last_prompt_idx < len(prompt_text):
+            print(f"\nšŸ“‹ Raw Prompt:\n{prompt_text[last_prompt_idx]}")
+            print("\n" + "="*10 + "\n")
+
+        # Get the actual tokenized prompt that was fed to the model
+        if last_prompt_idx < len(prompt_tokens):
+            actual_prompt = tokenizer.decode(prompt_tokens[last_prompt_idx])
+            print(f"\nšŸ”„ Model Input:\n{actual_prompt}")
+            print("\n" + "="*10 + "\n")
+
         print(f"\nšŸ“ Generation:\n{all_completion_texts[-1]}")
         print("\n" + "="*10 + "\n")
 
         # Make sure we have a valid index for answer_text
-        last_prompt_idx = batch_indices[-1] if batch_indices else 0
         if last_prompt_idx < len(answer_text):
             print(f"\nāœ… Answer:\n{answer_text[last_prompt_idx]}")
             print("\n" + "="*10 + "\n")