Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-26 02:33:23 +08:00)

Commit 53185c7f3d ("last update, gn")
Parent: e4eac9c97b
@@ -387,7 +387,8 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set

         test_ppl = math.exp(test_loss)

-        print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}, Rewards: {test_rewards[0]:.3f}, {test_rewards[1]:.3f}")
+        rewards_str = ", ".join([f"{k}: {v:.3f}" for k, v in test_rewards.items()])
+        print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}, Rewards: {rewards_str}")
     else:
         test_loss = evaluate(
             model=model,
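The new logging builds the reward summary from a dict instead of assuming exactly two positional rewards, so adding or removing reward functions no longer breaks the print. A minimal sketch of the same formatting, with made-up reward names and values:

    # Illustrative only: the keys and values below are placeholders, not output
    # from the repository's evaluate_grpo.
    test_rewards = {"accuracy_reward": 0.812, "strict_format_reward": 0.375}
    rewards_str = ", ".join([f"{k}: {v:.3f}" for k, v in test_rewards.items()])
    print(f"Rewards: {rewards_str}")
    # Rewards: accuracy_reward: 0.812, strict_format_reward: 0.375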
@@ -55,7 +55,7 @@ def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
 def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     if not completions:
         return [0.0] * len(prompts)
-    pattern = r"<think>\n.*?\n</think>\n<answer>*?</answer>"
+    pattern = r"<think> .*? </think><answer> .*? </answer>"
     matches = [bool(re.search(pattern, r)) if r else False for r in completions]
     return [0.5 if match else 0.0 for match in matches]

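The old strict-format pattern required literal newlines between the tags and had a malformed answer segment (`<answer>*?` repeats the '>' character rather than matching the answer body), so correctly formatted completions could still score 0.0. A quick check of the new pattern against a sample completion (the completion string is invented for illustration):

    import re

    pattern = r"<think> .*? </think><answer> .*? </answer>"
    completion = "<think> 3 + 4 = 7 </think><answer> 7 </answer>"
    print(bool(re.search(pattern, completion)))  # True

Note that `.` does not match newlines by default, so multi-line reasoning inside the <think> block would still fail this strict check unless re.DOTALL were used.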
@@ -10,7 +10,7 @@ import mlx.core as mx
 import mlx.nn as nn
 import numpy as np

-from .grpo_reward_functions import r1_accuracy_reward_func, r1_int_reward_func, r1_strict_format_reward_func, r1_soft_format_reward_func, r1_count_xml, RewardFunctions
+from .grpo_reward_functions import r1_accuracy_reward_func, r1_int_reward_func, r1_strict_format_reward_func, r1_soft_format_reward_func, r1_count_xml, r1_extract_xml_answer, RewardFunctions
 from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients
 from ..utils import generate_step
 from ..models import cache
@@ -177,8 +177,10 @@ def grpo_loss(
             prompt_tensor,
             max_tokens,
             tokenizer,
-            group_size
+            group_size,
+            temperature=temperature
         )
+        model.train()
     else:
         completions = generate_grpo(
             model,
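generate_grpo now receives the sampling temperature explicitly instead of relying on a default, and the model is switched back to train mode after generation. The snippet below is only a rough sketch of how a temperature typically enters token sampling in MLX; it is not the body of generate_grpo:

    import mlx.core as mx

    def sample_token(logits: mx.array, temperature: float) -> mx.array:
        # Greedy decoding when the temperature is (near) zero; otherwise scale
        # the logits by 1/temperature and draw from the categorical distribution.
        if temperature < 1e-6:
            return mx.argmax(logits, axis=-1)
        return mx.random.categorical(logits * (1.0 / temperature))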
@@ -327,8 +329,13 @@ def grpo_loss(
     }

     if is_validation:
-        print(f"\nValidation sample generation:\n{all_completion_texts}\n")
-        print(f"Validation sample answer:\n{answer_text[-1]}\n")
+        print("\n=== Validation Sample Details ===")
+        print(f"\n📝 Generation:\n{all_completion_texts[-1]}")
+        print("\n" + "="*10 + "\n")
+        print(f"\n✅ Answer:\n{answer_text[-1]}")
+        print("\n" + "="*10 + "\n")
+        print(f"\n🔍 Extracted Answer:\n{r1_extract_xml_answer(all_completion_texts[-1])}")
+        print("\n" + "="*30 + "\n")
     mx.metal.clear_cache()

     return loss, sequence_lengths.sum(), metrics
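The validation printout now also shows the answer extracted from the generated text via r1_extract_xml_answer (newly imported above). Its implementation is not part of this diff; a plausible minimal version, shown here only as an assumption to make the printed value concrete, pulls the text between the <answer> tags:

    def r1_extract_xml_answer(text: str) -> str:
        # Assumed behavior: return the content of the last <answer>...</answer>
        # block, or an empty string when the tags are missing.
        if "<answer>" not in text or "</answer>" not in text:
            return ""
        return text.split("<answer>")[-1].split("</answer>")[0].strip()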
@@ -396,7 +403,13 @@ def evaluate_grpo(
     max_seq_length: int,
     max_tokens: int,
     temperature: float,
-    reward_funcs: Optional[List[RewardFunctions]] = None,
+    reward_funcs: Optional[List[RewardFunctions]] = [
+        r1_accuracy_reward_func,
+        r1_int_reward_func,
+        r1_strict_format_reward_func,
+        r1_soft_format_reward_func,
+        r1_count_xml
+    ],
     loss_fn: callable = grpo_loss,
     iterate_batches: callable = iterate_grpo_batches
 ):
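The new signature puts a list literal directly in the default for reward_funcs. Because the list is only read, this is harmless here, but the more common Python idiom keeps the default as None and resolves it inside the function, which avoids sharing one mutable object across calls. A self-contained sketch of that alternative (the toy reward functions and evaluate wrapper are placeholders, not the repository's code):

    from typing import Callable, List, Optional

    def reward_a(completions: List[str]) -> List[float]:
        return [1.0 for _ in completions]

    def reward_b(completions: List[str]) -> List[float]:
        return [0.5 for _ in completions]

    DEFAULT_REWARD_FUNCS: List[Callable] = [reward_a, reward_b]

    def evaluate(completions: List[str],
                 reward_funcs: Optional[List[Callable]] = None) -> List[float]:
        # Resolve the default inside the function instead of in the signature.
        funcs = DEFAULT_REWARD_FUNCS if reward_funcs is None else reward_funcs
        return [sum(f(completions)[i] for f in funcs) for i in range(len(completions))]

    print(evaluate(["<answer> 7 </answer>"]))  # [1.5]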
@@ -550,7 +563,7 @@ def train_grpo(
             val_time = time.perf_counter() - stop
             if rank == 0:
                 val_metrics_str = (
-                    f"Val loss {val_loss:.8f}, "
+                    f"Val loss {val_loss:.3f}, "
                     f"Val total_rewards_mean {val_metrics['total_rewards_mean']:.3f}, "
                     f"Val total_rewards_std {val_metrics['total_rewards_std']:.3f}, "
                     f"Val grouped_rewards_mean {val_metrics['grouped_rewards_mean']:.3f}, "
@@ -605,7 +618,7 @@ def train_grpo(

         if rank == 0:
             train_metrics_str = (
-                f"Train loss {train_loss:.8f}, "
+                f"Train loss {train_loss:.3f}, "
                 f"Total rewards mean {avg_metrics['total_rewards_mean']:.3f}, "
                 f"Total rewards std {avg_metrics['total_rewards_std']:.3f}, "
                 f"Grouped rewards mean {avg_metrics['grouped_rewards_mean']:.3f}, "
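The last two hunks change only the printed loss precision, from eight decimal places to three, matching the precision already used for the reward statistics. For reference (the value is made up):

    val_loss = 0.123456789
    print(f"Val loss {val_loss:.8f}")  # Val loss 0.12345679
    print(f"Val loss {val_loss:.3f}")  # Val loss 0.123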