mirror of https://github.com/ml-explore/mlx-examples.git
adding function for R1
parent 243c9621d9
commit d034ca369e
@@ -1,11 +1,9 @@
 # Copyright © 2024 Apple Inc.
 
-import glob
-import shutil
 import time
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Union
+import re
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -38,40 +36,128 @@ class GRPOTrainingArgs(TrainingArgs):
     )
 
 
-def compute_default_rewards(sequences, batch_size, group_size):
-    """
-    Args:
-        sequences: List of word sequences
-        batch_size: Number of original prompts
-        group_size: Number of generations per prompt
-    """
-    rewards = mx.zeros((len(sequences),))
-
-    for i, sequence in enumerate(sequences):
-        # Convert sequence to list if it isn't already
-        if not isinstance(sequence, list):
-            sequence = sequence.split()
-
-        # Get the target (reversed) sequence
-        target = sequence[::-1]
-
-        # Calculate accuracy of reversal
-        correct_positions = sum(1 for a, b in zip(sequence, target) if a == b)
-        rewards[i] = correct_positions / len(sequence)
-
-    return rewards
+def r1_extract_xml_answer(text: str) -> str:
+    """Extracts the answer from an XML formatted text string."""
+    try:
+        answer = text.split("<answer>")[-1]
+        answer = answer.split("</answer>")[0]
+        return answer.strip()
+    except:
+        print("[extract_xml_answer] Failed to extract answer from: ", text)
+        return ""
+
+def r1_accuracy_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
+    """Calculates reward based on accuracy of extracted answers.
+
+    Args:
+        prompts: List of input prompts
+        completions: List of completion strings
+        answer: Expected answer or list of answers
+        **kwargs: Additional arguments
+
+    Returns:
+        list[float]: Reward values for each completion
+    """
+    extracted_responses = [r1_extract_xml_answer(r) for r in completions]
+    q = prompts[0] if isinstance(prompts[0], str) else prompts[0][-1]['content']
+    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
+
+def r1_int_reward_func(completions, **kwargs) -> list[float]:
+    """Rewards numerical responses.
+
+    Args:
+        completions: List of completion strings
+        **kwargs: Additional arguments
+
+    Returns:
+        list[float]: Reward values for each completion
+    """
+    extracted_responses = [r1_extract_xml_answer(r) for r in completions]
+    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
+
+def r1_strict_format_reward_func(completions, **kwargs) -> list[float]:
+    """Rewards completions with strict XML format.
+
+    Args:
+        completions: List of completion strings
+        **kwargs: Additional arguments
+
+    Returns:
+        list[float]: Reward values for each completion
+    """
+    pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>\n$"
+    matches = [re.match(pattern, r) for r in completions]
+    return [0.5 if match else 0.0 for match in matches]
+
+def r1_soft_format_reward_func(completions, **kwargs) -> list[float]:
+    """Rewards completions with flexible XML format.
+
+    Args:
+        completions: List of completion strings
+        **kwargs: Additional arguments
+
+    Returns:
+        list[float]: Reward values for each completion
+    """
+    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
+    matches = [re.match(pattern, r) for r in completions]
+    return [0.5 if match else 0.0 for match in matches]
+
+def r1_count_xml(text: str) -> float:
+    """Calculates score based on XML formatting.
+
+    Args:
+        text: Input text string
+
+    Returns:
+        float: Score based on XML tag presence and formatting
+    """
+    count = 0.0
+    if text.count("<think>\n") == 1:
+        count += 0.125
+    if text.count("\n</think>\n") == 1:
+        count += 0.125
+    if text.count("\n<answer>\n") == 1:
+        count += 0.125
+        count -= len(text.split("\n</answer>\n")[-1])*0.001
+    if text.count("\n</answer>") == 1:
+        count += 0.125
+        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
+    return count
 
 
 def grpo_loss(
     model,
     tokenizer,
     prompts,
-    reward_funcs=None,
+    reward_funcs=[
+        r1_accuracy_reward_func,
+        r1_int_reward_func,
+        r1_strict_format_reward_func,
+        r1_soft_format_reward_func,
+        r1_count_xml
+    ],
     beta=0.1,
     group_size=4,
-    epslion=1e-4,
-    ref_model = None
+    epsilon=1e-4,
+    ref_model=None
 ):
+    """
+    Calculates the GRPO loss with support for multiple reward functions.
+
+    Args:
+        model: The model to optimize
+        tokenizer: Tokenizer for processing inputs
+        prompts: List of input prompts
+        reward_funcs: List of reward functions to use
+        beta: KL penalty coefficient
+        group_size: Number of completions per prompt
+        epsilon: Small constant for numerical stability
+        ref_model: Optional reference model for KL divergence
+
+    Returns:
+        tuple: (loss, total_sequence_length, metrics_dict)
+    """
     batch_size = len(prompts)
     # Generate multiple completions for each prompt
     all_completions = []
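
A quick usage sketch of the new R1 reward functions (not part of the commit). The import path below is an assumption, since the diff does not show where this trainer file lives in the package; adjust it accordingly.

# Hypothetical import path; point it at the module patched above.
from mlx_lm.tuner.grpo_trainer import (
    r1_accuracy_reward_func,
    r1_count_xml,
    r1_int_reward_func,
)

completion = "<think>\nTwo plus two is four.\n</think>\n<answer>\n4\n</answer>\n"

# The extracted answer "4" matches the reference, so the accuracy reward is 2.0.
print(r1_accuracy_reward_func(
    prompts=["What is 2 + 2?"], completions=[completion], answer=["4"]
))  # [2.0]

# The extracted answer is all digits, so the integer reward is 0.5.
print(r1_int_reward_func(completions=[completion]))  # [0.5]

# All four tag checks pass and there is no trailing text, so the XML count is 0.5.
print(r1_count_xml(completion))  # 0.5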
@@ -83,7 +169,7 @@ def grpo_loss(
             prompt_completions.append(completion)
         all_completions.extend(prompt_completions)
 
-    # Tokenize all prompts + completions (needed for model processing)
+    # Tokenize all prompts + completions
     tokenized_inputs = tokenizer(
         [p + c for p, c in zip(prompts * group_size, all_completions)],
         return_tensors="np",
@@ -102,7 +188,7 @@
     # Calculate log probabilities
     log_probs = mx.log_softmax(logits[:, :-1, :], axis=-1)
 
-    # Prepare targets (shift input_ids left by one position)
+    # Prepare targets
     targets = inputs[:, 1:]
 
     # Gather actual token probabilities
@@ -125,29 +211,33 @@
         axis=-1
     ).squeeze(-1)
 
-    # Compute the KL divergence between the model and the reference model
+    # Compute KL divergence
     kl_div = (mx.exp(ref_token_log_probs - token_log_probs) - (ref_token_log_probs - token_log_probs) - 1)
 
-    # Calculate rewards
-    if reward_funcs:
-        rewards = mx.array([sum(rf(all_completions) for rf in reward_funcs)])
-    else:
-        rewards = compute_default_rewards(all_completions, batch_size, group_size)
+    # Calculate combined rewards from all reward functions
+    rewards = mx.zeros((len(all_completions),))
+    for reward_func in reward_funcs:
+        func_rewards = mx.array(reward_func(all_completions))
+        rewards += func_rewards
+
+    # Normalize rewards if using multiple reward functions
+    if len(reward_funcs) > 1:
+        rewards /= len(reward_funcs)
 
     # Compute grouped-wise rewards
     grouped_rewards = rewards.reshape(batch_size, group_size)
     mean_grouped_rewards = mx.mean(grouped_rewards, axis=1)
     std_grouped_rewards = mx.std(grouped_rewards, axis=1)
 
-    # Normalize the rewards to compute the advantages
+    # Normalize rewards to compute advantages
     mean_grouped_rewards = mx.repeat(mean_grouped_rewards.reshape(-1, 1), group_size, axis=1).reshape(-1)
     std_grouped_rewards = mx.repeat(std_grouped_rewards.reshape(-1, 1), group_size, axis=1).reshape(-1)
-    advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + epslion)
+    advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + epsilon)
 
     # Create length mask for the shifted sequence
     length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)
 
-    # Calculate policy gradient loss, mx.stop_gradient allows for preserving gradients from token_log_probs
+    # Calculate policy gradient loss
     per_token_loss = mx.exp(token_log_probs - mx.stop_gradient(token_log_probs)) * advantages.reshape(-1, 1)
     per_token_loss = -(per_token_loss - beta * kl_div)
 
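
For intuition, a minimal self-contained sketch (not from the commit) of the group-wise advantage normalization above, using made-up reward values; the mx.repeat / mx.mean / mx.std calls mirror the code exactly.

import mlx.core as mx

batch_size, group_size, epsilon = 2, 4, 1e-4

# Made-up combined rewards: group_size completions per prompt, batch_size prompts.
rewards = mx.array([2.5, 0.5, 1.0, 0.0,
                    3.0, 3.0, 0.5, 1.5])

grouped_rewards = rewards.reshape(batch_size, group_size)
mean_grouped = mx.repeat(mx.mean(grouped_rewards, axis=1).reshape(-1, 1), group_size, axis=1).reshape(-1)
std_grouped = mx.repeat(mx.std(grouped_rewards, axis=1).reshape(-1, 1), group_size, axis=1).reshape(-1)

# Completions that beat their own prompt's average get a positive advantage,
# the rest a negative one; epsilon guards against a zero std within a group.
advantages = (rewards - mean_grouped) / (std_grouped + epsilon)
print(advantages)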
@@ -156,14 +246,24 @@
     sequence_lengths = length_mask.sum(axis=1)
     loss = (sequence_sums / sequence_lengths).mean()
 
-    # Calculate mean KL divergence (normalized per sequence)
+    # Calculate mean KL divergence
     mean_kl = ((kl_div * length_mask).sum(axis=1) / length_mask.sum(axis=1)).mean()
+
+    # Collect metrics for each reward function separately
+    reward_metrics = {}
+    for i, reward_func in enumerate(reward_funcs):
+        func_rewards = mx.array(reward_func(all_completions))
+        func_grouped_rewards = func_rewards.reshape(batch_size, group_size)
+        reward_metrics[f'reward_func_{i}_mean'] = mx.mean(func_rewards)
+        reward_metrics[f'reward_func_{i}_std'] = mx.std(func_rewards)
+
     metrics = {
-        'rewards': rewards,
-        'rewards_std': mx.std(rewards),
-        'grouped_rewards': grouped_rewards,
+        'total_rewards_mean': mx.mean(rewards),
+        'total_rewards_std': mx.std(rewards),
+        'grouped_rewards_mean': mx.mean(grouped_rewards),
         'grouped_rewards_std': mx.std(grouped_rewards),
-        'kl': mean_kl
+        'kl': mean_kl,
+        **reward_metrics
     }
 
     return loss, sequence_lengths.sum(), metrics
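
Finally, a small runnable illustration (not from the commit) of the per-reward-function metric collection added above, with toy stand-in reward functions in place of the real ones:

import mlx.core as mx

completions = ["<answer>4</answer>", "42", "no tags here", "<answer>7</answer>"]

# Toy stand-ins that follow the same call convention used in grpo_loss:
# each takes the list of completions and returns one float per completion.
reward_funcs = [
    lambda comps: [2.0 if "<answer>" in c else 0.0 for c in comps],
    lambda comps: [0.5 if c.strip().isdigit() else 0.0 for c in comps],
]

reward_metrics = {}
for i, reward_func in enumerate(reward_funcs):
    func_rewards = mx.array(reward_func(completions))
    reward_metrics[f"reward_func_{i}_mean"] = mx.mean(func_rewards)
    reward_metrics[f"reward_func_{i}_std"] = mx.std(func_rewards)

# One mean and one std per reward function, keyed reward_func_0_*, reward_func_1_*.
print(reward_metrics)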