diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index e9e86e14..85b3528e 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -387,7 +387,8 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set
 
         test_ppl = math.exp(test_loss)
 
-        print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}, Rewards: {test_rewards[0]:.3f}, {test_rewards[1]:.3f}")
+        rewards_str = ", ".join([f"{k}: {v:.3f}" for k, v in test_rewards.items()])
+        print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}, Rewards: {rewards_str}")
     else:
         test_loss = evaluate(
             model=model,
diff --git a/llms/mlx_lm/tuner/grpo_reward_functions.py b/llms/mlx_lm/tuner/grpo_reward_functions.py
index 59dfbfef..3b5c56b5 100644
--- a/llms/mlx_lm/tuner/grpo_reward_functions.py
+++ b/llms/mlx_lm/tuner/grpo_reward_functions.py
@@ -55,7 +55,7 @@ def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
 def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     if not completions:
         return [0.0] * len(prompts)
-    pattern = r"<think>\n.*?\n</think>\n<answer>*?</answer>"
+    pattern = r"<think> .*? </think><answer> .*? </answer>"
     matches = [bool(re.search(pattern, r)) if r else False for r in completions]
     return [0.5 if match else 0.0 for match in matches]
diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index 5ec3020a..d41bedce 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -10,7 +10,7 @@
 import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
-from .grpo_reward_functions import r1_accuracy_reward_func, r1_int_reward_func, r1_strict_format_reward_func, r1_soft_format_reward_func, r1_count_xml, RewardFunctions
+from .grpo_reward_functions import r1_accuracy_reward_func, r1_int_reward_func, r1_strict_format_reward_func, r1_soft_format_reward_func, r1_count_xml, r1_extract_xml_answer, RewardFunctions
 from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients
 from ..utils import generate_step
 from ..models import cache
@@ -173,18 +173,20 @@ def grpo_loss(
     try:
         if is_validation:
             completions = generate_grpo(
-                model, 
-                prompt_tensor, 
-                max_tokens, 
+                model,
+                prompt_tensor,
+                max_tokens,
                 tokenizer,
-                group_size
+                group_size,
+                temperature=temperature
             )
+            model.train()
         else:
             completions = generate_grpo(
-                model, 
-                prompt_tensor, 
-                max_tokens, 
-                tokenizer, 
+                model,
+                prompt_tensor,
+                max_tokens,
+                tokenizer,
                 group_size,
                 is_training=True,
                 temperature=temperature
@@ -327,8 +329,13 @@ def grpo_loss(
     }
 
     if is_validation:
-        print(f"\nValidation sample generation:\n{all_completion_texts}\n")
-        print(f"Validation sample answer:\n{answer_text[-1]}\n")
+        print("\n=== Validation Sample Details ===")
+        print(f"\nšŸ“ Generation:\n{all_completion_texts[-1]}")
+        print("\n" + "="*10 + "\n")
+        print(f"\nāœ… Answer:\n{answer_text[-1]}")
+        print("\n" + "="*10 + "\n")
+        print(f"\nšŸ” Extracted Answer:\n{r1_extract_xml_answer(all_completion_texts[-1])}")
+        print("\n" + "="*30 + "\n")
 
     mx.metal.clear_cache()
     return loss, sequence_lengths.sum(), metrics
@@ -396,7 +403,13 @@ def evaluate_grpo(
     max_seq_length: int,
     max_tokens: int,
     temperature: float,
-    reward_funcs: Optional[List[RewardFunctions]] = None,
+    reward_funcs: Optional[List[RewardFunctions]] = [
+        r1_accuracy_reward_func,
+        r1_int_reward_func,
+        r1_strict_format_reward_func,
+        r1_soft_format_reward_func,
+        r1_count_xml
+    ],
     loss_fn: callable = grpo_loss,
     iterate_batches: callable = iterate_grpo_batches
 ):
@@ -550,7 +563,7 @@ def train_grpo(
             val_time = time.perf_counter() - stop
             if rank == 0:
                 val_metrics_str = (
-                    f"Val loss {val_loss:.8f}, "
+                    f"Val loss {val_loss:.3f}, "
                     f"Val total_rewards_mean {val_metrics['total_rewards_mean']:.3f}, "
                     f"Val total_rewards_std {val_metrics['total_rewards_std']:.3f}, "
                     f"Val grouped_rewards_mean {val_metrics['grouped_rewards_mean']:.3f}, "
@@ -605,7 +618,7 @@ def train_grpo(
 
             if rank == 0:
                 train_metrics_str = (
-                    f"Train loss {train_loss:.8f}, "
+                    f"Train loss {train_loss:.3f}, "
                     f"Total rewards mean {avg_metrics['total_rewards_mean']:.3f}, "
                     f"Total rewards std {avg_metrics['total_rewards_std']:.3f}, "
                     f"Grouped rewards mean {avg_metrics['grouped_rewards_mean']:.3f}, "
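
For reference, a minimal sketch of what the reworked reward reporting in evaluate_model prints, assuming evaluate_grpo now returns test_rewards as a dict mapping reward-function names to mean scores; the keys and values below are illustrative only, not output from the actual trainer:

# Sketch only: assumes test_rewards is a dict keyed by reward-function name.
import math

test_loss = 1.234
test_ppl = math.exp(test_loss)
test_rewards = {"r1_accuracy_reward_func": 0.512, "r1_count_xml": 0.231}  # hypothetical values

rewards_str = ", ".join([f"{k}: {v:.3f}" for k, v in test_rewards.items()])
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}, Rewards: {rewards_str}")
# Prints: Test loss 1.234, Test ppl 3.435, Rewards: r1_accuracy_reward_func: 0.512, r1_count_xml: 0.231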