Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-26 02:33:23 +08:00)

adding function for R1

This commit is contained in:
parent 243c9621d9
commit d034ca369e
@@ -1,11 +1,9 @@
# Copyright © 2024 Apple Inc.

import glob
import shutil
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Union
import re

import mlx.core as mx
import mlx.nn as nn
@@ -38,40 +36,128 @@ class GRPOTrainingArgs(TrainingArgs):
    )


-def compute_default_rewards(sequences, batch_size, group_size):
-    """
-    Args:
-        sequences: List of word sequences
-        batch_size: Number of original prompts
-        group_size: Number of generations per prompt
-    """
-    rewards = mx.zeros((len(sequences),))
-
-    for i, sequence in enumerate(sequences):
-        # Convert sequence to list if it isn't already
-        if not isinstance(sequence, list):
-            sequence = sequence.split()
-
-        # Get the target (reversed) sequence
-        target = sequence[::-1]
-
-        # Calculate accuracy of reversal
-        correct_positions = sum(1 for a, b in zip(sequence, target) if a == b)
-        rewards[i] = correct_positions / len(sequence)
-
-    return rewards
+def r1_extract_xml_answer(text: str) -> str:
+    """Extracts the answer from an XML formatted text string."""
+    try:
+        answer = text.split("<answer>")[-1]
+        answer = answer.split("</answer>")[0]
+        return answer.strip()
+    except:
+        print("[extract_xml_answer] Failed to extract answer from: ", text)
+        return ""

+def r1_accuracy_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
+    """Calculates reward based on accuracy of extracted answers.
+
+    Args:
+        prompts: List of input prompts
+        completions: List of completion strings
+        answer: Expected answer or list of answers
+        **kwargs: Additional arguments
+
+    Returns:
+        list[float]: Reward values for each completion
+    """
+    extracted_responses = [r1_extract_xml_answer(r) for r in completions]
+    q = prompts[0] if isinstance(prompts[0], str) else prompts[0][-1]['content']
+    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

+def r1_int_reward_func(completions, **kwargs) -> list[float]:
+    """Rewards numerical responses.
+
+    Args:
+        completions: List of completion strings
+        **kwargs: Additional arguments
+
+    Returns:
+        list[float]: Reward values for each completion
+    """
+    extracted_responses = [r1_extract_xml_answer(r) for r in completions]
+    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
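Aside (not part of the commit): the three functions above implement the R1-style reward shaping of extracting the `<answer>` block, comparing it to the reference answer, and giving a small bonus for integer answers. A minimal sketch of how they score a single completion; the sample strings and expected values are invented for illustration.

```python
# Illustrative sketch only -- the completion and reference answer are made up.
completion = "<think>\n6 * 7 = 42\n</think>\n<answer>\n42\n</answer>\n"

# r1_extract_xml_answer: take the text between <answer> ... </answer>.
extracted = completion.split("<answer>")[-1].split("</answer>")[0].strip()
print(extracted)  # "42"

# r1_accuracy_reward_func: 2.0 when the extracted answer equals the reference.
print([2.0 if r == a else 0.0 for r, a in zip([extracted], ["42"])])  # [2.0]

# r1_int_reward_func: 0.5 whenever the extracted answer is a digit string,
# whether or not it is correct.
print([0.5 if r.isdigit() else 0.0 for r in [extracted]])  # [0.5]
```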

+def r1_strict_format_reward_func(completions, **kwargs) -> list[float]:
+    """Rewards completions with strict XML format.
+
+    Args:
+        completions: List of completion strings
+        **kwargs: Additional arguments
+
+    Returns:
+        list[float]: Reward values for each completion
+    """
+    pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>\n$"
+    matches = [re.match(pattern, r) for r in completions]
+    return [0.5 if match else 0.0 for match in matches]

+def r1_soft_format_reward_func(completions, **kwargs) -> list[float]:
+    """Rewards completions with flexible XML format.
+
+    Args:
+        completions: List of completion strings
+        **kwargs: Additional arguments
+
+    Returns:
+        list[float]: Reward values for each completion
+    """
+    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
+    matches = [re.match(pattern, r) for r in completions]
+    return [0.5 if match else 0.0 for match in matches]

+def r1_count_xml(text: str) -> float:
+    """Calculates score based on XML formatting.
+
+    Args:
+        text: Input text string
+
+    Returns:
+        float: Score based on XML tag presence and formatting
+    """
+    count = 0.0
+    if text.count("<think>\n") == 1:
+        count += 0.125
+    if text.count("\n</think>\n") == 1:
+        count += 0.125
+    if text.count("\n<answer>\n") == 1:
+        count += 0.125
+        count -= len(text.split("\n</answer>\n")[-1])*0.001
+    if text.count("\n</answer>") == 1:
+        count += 0.125
+        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
+    return count
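Aside (not part of the diff; the example strings below are invented): the strict pattern requires the exact `<think>`/`<answer>` layout with single-line bodies, the soft pattern only requires the tags in order, and `r1_count_xml` awards partial credit per well-placed tag with a small penalty for trailing text.

```python
import re

strict_pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>\n$"
soft_pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"

well_formed = "<think>\n6 * 7 = 42\n</think>\n<answer>\n42\n</answer>\n"
loose = "<think>6 * 7 = 42</think> <answer>42</answer>"

# Without re.DOTALL, '.' does not cross newlines, so the strict pattern only
# accepts single-line <think>/<answer> bodies in exactly this layout.
print(bool(re.match(strict_pattern, well_formed)))  # True  -> 0.5 strict reward
print(bool(re.match(strict_pattern, loose)))        # False -> 0.0
print(bool(re.match(soft_pattern, loose)))          # True  -> 0.5 soft reward

# r1_count_xml on the well-formed string: all four tag checks pass and only a
# trailing newline follows </answer>, so no length penalty applies and the
# score is 4 * 0.125 = 0.5.
```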
def grpo_loss(
    model,
    tokenizer,
    prompts,
-    reward_funcs=None,
+    reward_funcs=[
+        r1_accuracy_reward_func,
+        r1_int_reward_func,
+        r1_strict_format_reward_func,
+        r1_soft_format_reward_func,
+        r1_count_xml
+    ],
    beta=0.1,
    group_size=4,
-    epslion=1e-4,
-    ref_model = None
+    epsilon=1e-4,
+    ref_model=None
):
    """
    Calculates the GRPO loss with support for multiple reward functions.

    Args:
        model: The model to optimize
        tokenizer: Tokenizer for processing inputs
        prompts: List of input prompts
        reward_funcs: List of reward functions to use
        beta: KL penalty coefficient
        group_size: Number of completions per prompt
        epsilon: Small constant for numerical stability
        ref_model: Optional reference model for KL divergence

    Returns:
        tuple: (loss, total_sequence_length, metrics_dict)
    """
    batch_size = len(prompts)
    # Generate multiple completions for each prompt
    all_completions = []
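With this change `reward_funcs` defaults to the full R1 set, so a caller can omit it or pass a subset. A hypothetical call shape is sketched below; everything other than `grpo_loss` and the `r1_*` names (the variables `model`, `tokenizer`, `batch_prompts`, `frozen_ref`) is invented for illustration and not taken from the commit.

```python
# Hypothetical call shape, assuming a surrounding training step provides the
# model, tokenizer, prompt batch, and a frozen reference model.
loss, num_tokens, metrics = grpo_loss(
    model=model,
    tokenizer=tokenizer,
    prompts=batch_prompts,
    reward_funcs=[r1_soft_format_reward_func, r1_int_reward_func],  # subset override
    beta=0.1,          # weight of the KL penalty against the reference model
    group_size=4,      # completions sampled per prompt
    epsilon=1e-4,      # numerical-stability constant for advantage scaling
    ref_model=frozen_ref,
)
```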
@@ -83,7 +169,7 @@ def grpo_loss(
            prompt_completions.append(completion)
        all_completions.extend(prompt_completions)

-    # Tokenize all prompts + completions (needed for model processing)
+    # Tokenize all prompts + completions
    tokenized_inputs = tokenizer(
        [p + c for p, c in zip(prompts * group_size, all_completions)],
        return_tensors="np",
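A small aside on the pairing in the tokenizer call above (toy data, not from the commit): `prompts * group_size` repeats the whole prompt list, while the generation loop builds `all_completions` prompt by prompt, so the zip pairs items in list order.

```python
# Toy illustration (invented data) of the prompt/completion pairing with
# batch_size=2 and group_size=2.
prompts = ["P0 ", "P1 "]
all_completions = ["c00", "c01", "c10", "c11"]  # built prompt by prompt
group_size = 2

pairs = [p + c for p, c in zip(prompts * group_size, all_completions)]
print(pairs)  # ['P0 c00', 'P1 c01', 'P0 c10', 'P1 c11']

# Note: `prompts * group_size` cycles the prompt list (P0, P1, P0, P1), which
# lines up with completions ordered round-robin across prompts; an expression
# like [p for p in prompts for _ in range(group_size)] would instead match the
# prompt-by-prompt ordering produced by the generation loop.
```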
@@ -102,7 +188,7 @@
    # Calculate log probabilities
    log_probs = mx.log_softmax(logits[:, :-1, :], axis=-1)

-    # Prepare targets (shift input_ids left by one position)
+    # Prepare targets
    targets = inputs[:, 1:]

    # Gather actual token probabilities
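The lines around this point implement the usual shift-and-gather pattern for next-token log-probabilities; the gather itself (presumably an `mx.take_along_axis`-style call, whose tail appears in the next hunk) is split across the hunk boundary. A self-contained sketch with invented shapes, computing the log-softmax via `mx.logsumexp`:

```python
import mlx.core as mx

B, T, V = 2, 5, 11                        # batch, sequence length, vocab (toy sizes)
inputs = mx.random.randint(0, V, (B, T))  # token ids
logits = mx.random.normal((B, T, V))      # model outputs

# Log-softmax over the vocabulary for every position except the last one.
log_probs = logits[:, :-1, :] - mx.logsumexp(logits[:, :-1, :], axis=-1, keepdims=True)

# Targets are the inputs shifted left by one position.
targets = inputs[:, 1:]

# Gather the log-probability assigned to each realized next token.
token_log_probs = mx.take_along_axis(
    log_probs, mx.expand_dims(targets, -1), axis=-1
).squeeze(-1)
print(token_log_probs.shape)  # (2, 4)
```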
@@ -125,29 +211,33 @@
        axis=-1
    ).squeeze(-1)

-    # Compute the KL divergence between the model and the reference model
+    # Compute KL divergence
    kl_div = (mx.exp(ref_token_log_probs - token_log_probs) - (ref_token_log_probs - token_log_probs) - 1)

-    # Calculate rewards
-    if reward_funcs:
-        rewards = mx.array([sum(rf(all_completions) for rf in reward_funcs)])
-    else:
-        rewards = compute_default_rewards(all_completions, batch_size, group_size)
+    # Calculate combined rewards from all reward functions
+    rewards = mx.zeros((len(all_completions),))
+    for reward_func in reward_funcs:
+        func_rewards = mx.array(reward_func(all_completions))
+        rewards += func_rewards

+    # Normalize rewards if using multiple reward functions
+    if len(reward_funcs) > 1:
+        rewards /= len(reward_funcs)
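Aside (not part of the diff): the `kl_div` expression a few lines up is the low-variance per-token KL estimator `exp(r) - r - 1` with `r = log p_ref - log p_theta` (often called the k3 estimator); it is non-negative and zero exactly where the two log-probabilities agree. The loop just above then sums the per-function rewards and averages them when more than one reward function is used. A tiny numeric sketch of the estimator with made-up values:

```python
import mlx.core as mx

token_log_probs = mx.array([-1.2, -0.7, -2.3])      # invented policy log-probs
ref_token_log_probs = mx.array([-1.2, -0.9, -2.0])  # invented reference log-probs

r = ref_token_log_probs - token_log_probs
kl_div = mx.exp(r) - r - 1
print(kl_div)  # element-wise, >= 0, and 0.0 where the two log-probs agree
```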
    # Compute grouped-wise rewards
    grouped_rewards = rewards.reshape(batch_size, group_size)
    mean_grouped_rewards = mx.mean(grouped_rewards, axis=1)
    std_grouped_rewards = mx.std(grouped_rewards, axis=1)

-    # Normalize the rewards to compute the advantages
+    # Normalize rewards to compute advantages
    mean_grouped_rewards = mx.repeat(mean_grouped_rewards.reshape(-1, 1), group_size, axis=1).reshape(-1)
    std_grouped_rewards = mx.repeat(std_grouped_rewards.reshape(-1, 1), group_size, axis=1).reshape(-1)
-    advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + epslion)
+    advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + epsilon)

    # Create length mask for the shifted sequence
    length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)

-    # Calculate policy gradient loss, mx.stop_gradient allows for preserving gradients from token_log_probs
+    # Calculate policy gradient loss
    per_token_loss = mx.exp(token_log_probs - mx.stop_gradient(token_log_probs)) * advantages.reshape(-1, 1)
    per_token_loss = -(per_token_loss - beta * kl_div)
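Aside (not part of the diff): each prompt's `group_size` completions are standardized against their own group mean and standard deviation, so better-than-average completions within a group get positive advantages; the `mx.exp(token_log_probs - mx.stop_gradient(token_log_probs))` factor evaluates to 1 in the forward pass but keeps the gradient path through `token_log_probs`. A toy example of the group-relative advantages, with invented rewards:

```python
import mlx.core as mx

batch_size, group_size, epsilon = 2, 4, 1e-4
rewards = mx.array([1.0, 0.0, 0.5, 0.5,   # completions for prompt 0 (invented)
                    2.0, 2.0, 0.0, 0.0])  # completions for prompt 1 (invented)

grouped = rewards.reshape(batch_size, group_size)
mean_g = mx.repeat(mx.mean(grouped, axis=1).reshape(-1, 1), group_size, axis=1).reshape(-1)
std_g = mx.repeat(mx.std(grouped, axis=1).reshape(-1, 1), group_size, axis=1).reshape(-1)

advantages = (rewards - mean_g) / (std_g + epsilon)
print(advantages)  # e.g. the 1.0 reward in group 0 maps to a positive advantage
```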
@@ -156,14 +246,24 @@ def grpo_loss(
    sequence_lengths = length_mask.sum(axis=1)
    loss = (sequence_sums / sequence_lengths).mean()

-    # Calculate mean KL divergence (normalized per sequence)
+    # Calculate mean KL divergence
    mean_kl = ((kl_div * length_mask).sum(axis=1) / length_mask.sum(axis=1)).mean()

+    # Collect metrics for each reward function separately
+    reward_metrics = {}
+    for i, reward_func in enumerate(reward_funcs):
+        func_rewards = mx.array(reward_func(all_completions))
+        func_grouped_rewards = func_rewards.reshape(batch_size, group_size)
+        reward_metrics[f'reward_func_{i}_mean'] = mx.mean(func_rewards)
+        reward_metrics[f'reward_func_{i}_std'] = mx.std(func_rewards)

    metrics = {
-        'rewards': rewards,
-        'rewards_std': mx.std(rewards),
-        'grouped_rewards': grouped_rewards,
+        'total_rewards_mean': mx.mean(rewards),
+        'total_rewards_std': mx.std(rewards),
+        'grouped_rewards_mean': mx.mean(grouped_rewards),
+        'grouped_rewards_std': mx.std(grouped_rewards),
-        'kl': mean_kl
+        'kl': mean_kl,
+        **reward_metrics
    }

    return loss, sequence_lengths.sum(), metrics
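Downstream of this function the returned tuple can be logged per step. A hypothetical snippet (not part of the diff; it assumes `loss`, `num_tokens`, and `metrics` came from a `grpo_loss` call inside the training loop, and the print format is invented):

```python
# `reward_func_0_mean` corresponds to the first entry of the default
# reward_funcs list, i.e. r1_accuracy_reward_func.
print(
    f"loss {loss.item():.4f} | tokens {int(num_tokens.item())} | "
    f"mean reward {metrics['total_rewards_mean'].item():.3f} | "
    f"kl {metrics['kl'].item():.4f} | "
    f"accuracy reward {metrics['reward_func_0_mean'].item():.3f}"
)
```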