Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-12-16 02:08:55 +08:00)

Commit: last update, gn
@@ -10,7 +10,7 @@ import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
 
-from .grpo_reward_functions import r1_accuracy_reward_func, r1_int_reward_func, r1_strict_format_reward_func, r1_soft_format_reward_func, r1_count_xml, RewardFunctions
+from .grpo_reward_functions import r1_accuracy_reward_func, r1_int_reward_func, r1_strict_format_reward_func, r1_soft_format_reward_func, r1_count_xml, r1_extract_xml_answer, RewardFunctions
 from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients
 from ..utils import generate_step
 from ..models import cache
@@ -173,18 +173,20 @@ def grpo_loss(
     try:
         if is_validation:
             completions = generate_grpo(
                 model,
                 prompt_tensor,
                 max_tokens,
                 tokenizer,
-                group_size
+                group_size,
+                temperature=temperature
             )
             model.train()
         else:
             completions = generate_grpo(
                 model,
                 prompt_tensor,
                 max_tokens,
                 tokenizer,
                 group_size,
                 is_training=True,
+                temperature=temperature
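Both branches of this hunk make the same change: the sampling temperature configured by the trainer is now passed through to generate_grpo instead of being left at that function's default. A minimal sketch of a signature consistent with the call sites above; the annotations, defaults, and return type are assumptions inferred from the diff, not taken from the repository:

from typing import List

def generate_grpo(
    model,                      # mlx.nn.Module used for sampling
    prompt_tensor,              # mx.array of tokenized prompts
    max_tokens: int,            # per-completion generation budget
    tokenizer,                  # decodes sampled token ids back to text
    group_size: int,            # completions per prompt (the GRPO "group")
    is_training: bool = False,  # True on the rollout path during training
    temperature: float = 1.0,   # sampling temperature, now caller-controlled
) -> List[str]:
    ...  # sampling loop elided; only the interface is sketched here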
@@ -327,8 +329,13 @@ def grpo_loss(
     }
 
     if is_validation:
-        print(f"\nValidation sample generation:\n{all_completion_texts}\n")
-        print(f"Validation sample answer:\n{answer_text[-1]}\n")
+        print("\n=== Validation Sample Details ===")
+        print(f"\n📝 Generation:\n{all_completion_texts[-1]}")
+        print("\n" + "="*10 + "\n")
+        print(f"\n✅ Answer:\n{answer_text[-1]}")
+        print("\n" + "="*10 + "\n")
+        print(f"\n🔍 Extracted Answer:\n{r1_extract_xml_answer(all_completion_texts[-1])}")
+        print("\n" + "="*30 + "\n")
     mx.metal.clear_cache()
 
     return loss, sequence_lengths.sum(), metrics
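The reworked validation report also prints the parsed answer via r1_extract_xml_answer, the helper newly imported in the first hunk. Assuming the R1-style completion format that the other r1_* reward functions check for, i.e. a final answer wrapped in <answer> tags, a plausible sketch of that helper:

def r1_extract_xml_answer(text: str) -> str:
    # Keep whatever follows the last opening tag, then cut at the closing tag.
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

# e.g. r1_extract_xml_answer("<think>...</think><answer>42</answer>") -> "42"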
@@ -396,7 +403,13 @@ def evaluate_grpo(
     max_seq_length: int,
     max_tokens: int,
     temperature: float,
-    reward_funcs: Optional[List[RewardFunctions]] = None,
+    reward_funcs: Optional[List[RewardFunctions]] = [
+        r1_accuracy_reward_func,
+        r1_int_reward_func,
+        r1_strict_format_reward_func,
+        r1_soft_format_reward_func,
+        r1_count_xml
+    ],
     loss_fn: callable = grpo_loss,
     iterate_batches: callable = iterate_grpo_batches
 ):
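One caveat on this hunk: a list literal used as a default argument is built once, at function definition time, and shared across all calls. The r1_* reward functions are never mutated, so this is benign in practice, but the conventional idiom keeps the None default the old line had and resolves it in the body. A sketch, assuming the module's existing typing and reward-function imports are in scope (DEFAULT_REWARD_FUNCS is a hypothetical name):

DEFAULT_REWARD_FUNCS: List[RewardFunctions] = [
    r1_accuracy_reward_func,
    r1_int_reward_func,
    r1_strict_format_reward_func,
    r1_soft_format_reward_func,
    r1_count_xml,
]

def evaluate_grpo(
    reward_funcs: Optional[List[RewardFunctions]] = None,
    # ...remaining parameters as in the hunk above...
):
    if reward_funcs is None:
        reward_funcs = DEFAULT_REWARD_FUNCS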
@@ -550,7 +563,7 @@ def train_grpo(
         val_time = time.perf_counter() - stop
         if rank == 0:
             val_metrics_str = (
-                f"Val loss {val_loss:.8f}, "
+                f"Val loss {val_loss:.3f}, "
                 f"Val total_rewards_mean {val_metrics['total_rewards_mean']:.3f}, "
                 f"Val total_rewards_std {val_metrics['total_rewards_std']:.3f}, "
                 f"Val grouped_rewards_mean {val_metrics['grouped_rewards_mean']:.3f}, "
@@ -605,7 +618,7 @@ def train_grpo(
 
         if rank == 0:
             train_metrics_str = (
-                f"Train loss {train_loss:.8f}, "
+                f"Train loss {train_loss:.3f}, "
                 f"Total rewards mean {avg_metrics['total_rewards_mean']:.3f}, "
                 f"Total rewards std {avg_metrics['total_rewards_std']:.3f}, "
                 f"Grouped rewards mean {avg_metrics['grouped_rewards_mean']:.3f}, "
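The final two hunks are cosmetic: the validation and training losses drop from eight decimal places to three, matching the precision the surrounding reward metrics already use. For illustration:

loss = 0.123456789
print(f"Val loss {loss:.8f}")  # Val loss 0.12345679
print(f"Val loss {loss:.3f}")  # Val loss 0.123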