diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index e9e86e14..85b3528e 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -387,7 +387,8 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set
test_ppl = math.exp(test_loss)
- print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}, Rewards: {test_rewards[0]:.3f}, {test_rewards[1]:.3f}")
+ rewards_str = ", ".join([f"{k}: {v:.3f}" for k, v in test_rewards.items()])
+ print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}, Rewards: {rewards_str}")
else:
test_loss = evaluate(
model=model,
diff --git a/llms/mlx_lm/tuner/grpo_reward_functions.py b/llms/mlx_lm/tuner/grpo_reward_functions.py
index 59dfbfef..3b5c56b5 100644
--- a/llms/mlx_lm/tuner/grpo_reward_functions.py
+++ b/llms/mlx_lm/tuner/grpo_reward_functions.py
@@ -55,7 +55,7 @@ def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, *
def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
if not completions:
return [0.0] * len(prompts)
- pattern = r"\n.*?\n\n*?"
+ pattern = r" .*? .*? "
matches = [bool(re.search(pattern, r)) if r else False for r in completions]
return [0.5 if match else 0.0 for match in matches]
diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index 5ec3020a..d41bedce 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -10,7 +10,7 @@ import mlx.core as mx
import mlx.nn as nn
import numpy as np
-from .grpo_reward_functions import r1_accuracy_reward_func, r1_int_reward_func, r1_strict_format_reward_func, r1_soft_format_reward_func, r1_count_xml, RewardFunctions
+from .grpo_reward_functions import r1_accuracy_reward_func, r1_int_reward_func, r1_strict_format_reward_func, r1_soft_format_reward_func, r1_count_xml,r1_extract_xml_answer, RewardFunctions
from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients
from ..utils import generate_step
from ..models import cache
@@ -173,18 +173,20 @@ def grpo_loss(
try:
if is_validation:
completions = generate_grpo(
- model,
- prompt_tensor,
- max_tokens,
+ model,
+ prompt_tensor,
+ max_tokens,
tokenizer,
- group_size
+ group_size,
+ temperature=temperature
)
+ model.train()
else:
completions = generate_grpo(
- model,
- prompt_tensor,
- max_tokens,
- tokenizer,
+ model,
+ prompt_tensor,
+ max_tokens,
+ tokenizer,
group_size,
is_training=True,
temperature=temperature
@@ -327,8 +329,13 @@ def grpo_loss(
}
if is_validation:
- print(f"\nValidation sample generation:\n{all_completion_texts}\n")
- print(f"Validation sample answer:\n{answer_text[-1]}\n")
+ print("\n=== Validation Sample Details ===")
+ print(f"\nš Generation:\n{all_completion_texts[-1]}")
+ print("\n" + "="*10 + "\n")
+ print(f"\nā
Answer:\n{answer_text[-1]}")
+ print("\n" + "="*10 + "\n")
+ print(f"\nš Extracted Answer:\n{r1_extract_xml_answer(all_completion_texts[-1])}")
+ print("\n" + "="*30 + "\n")
mx.metal.clear_cache()
return loss, sequence_lengths.sum(), metrics
@@ -396,7 +403,13 @@ def evaluate_grpo(
max_seq_length: int,
max_tokens: int,
temperature: float,
- reward_funcs: Optional[List[RewardFunctions]] = None,
+ reward_funcs: Optional[List[RewardFunctions]] = [
+ r1_accuracy_reward_func,
+ r1_int_reward_func,
+ r1_strict_format_reward_func,
+ r1_soft_format_reward_func,
+ r1_count_xml
+ ],
loss_fn: callable = grpo_loss,
iterate_batches: callable = iterate_grpo_batches
):
@@ -550,7 +563,7 @@ def train_grpo(
val_time = time.perf_counter() - stop
if rank == 0:
val_metrics_str = (
- f"Val loss {val_loss:.8f}, "
+ f"Val loss {val_loss:.3f}, "
f"Val total_rewards_mean {val_metrics['total_rewards_mean']:.3f}, "
f"Val total_rewards_std {val_metrics['total_rewards_std']:.3f}, "
f"Val grouped_rewards_mean {val_metrics['grouped_rewards_mean']:.3f}, "
@@ -605,7 +618,7 @@ def train_grpo(
if rank == 0:
train_metrics_str = (
- f"Train loss {train_loss:.8f}, "
+ f"Train loss {train_loss:.3f}, "
f"Total rewards mean {avg_metrics['total_rewards_mean']:.3f}, "
f"Total rewards std {avg_metrics['total_rewards_std']:.3f}, "
f"Grouped rewards mean {avg_metrics['grouped_rewards_mean']:.3f}, "