adding function for R1

Goekdeniz-Guelmez 2025-02-03 08:26:42 +01:00
parent 243c9621d9
commit d034ca369e


@@ -1,11 +1,9 @@
# Copyright © 2024 Apple Inc.
-import glob
-import shutil
-import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Union
import re
import mlx.core as mx
import mlx.nn as nn
@@ -38,40 +36,128 @@ class GRPOTrainingArgs(TrainingArgs):
)
-def compute_default_rewards(sequences, batch_size, group_size):
-    """
-    Args:
-        sequences: List of word sequences
-        batch_size: Number of original prompts
-        group_size: Number of generations per prompt
-    """
-    rewards = mx.zeros((len(sequences),))
-    for i, sequence in enumerate(sequences):
-        # Convert sequence to list if it isn't already
-        if not isinstance(sequence, list):
-            sequence = sequence.split()
-        # Get the target (reversed) sequence
-        target = sequence[::-1]
-        # Calculate accuracy of reversal
-        correct_positions = sum(1 for a, b in zip(sequence, target) if a == b)
-        rewards[i] = correct_positions / len(sequence)
-    return rewards

def r1_extract_xml_answer(text: str) -> str:
    """Extracts the answer from an XML formatted text string."""
    try:
        answer = text.split("<answer>")[-1]
        answer = answer.split("</answer>")[0]
        return answer.strip()
    except Exception:
        print("[extract_xml_answer] Failed to extract answer from: ", text)
        return ""

def r1_accuracy_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Calculates reward based on accuracy of extracted answers.

    Args:
        prompts: List of input prompts
        completions: List of completion strings
        answer: Expected answer or list of answers
        **kwargs: Additional arguments

    Returns:
        list[float]: Reward values for each completion
    """
    extracted_responses = [r1_extract_xml_answer(r) for r in completions]
    q = prompts[0] if isinstance(prompts[0], str) else prompts[0][-1]['content']
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def r1_int_reward_func(completions, **kwargs) -> list[float]:
    """Rewards numerical responses.
Args:
completions: List of completion strings
**kwargs: Additional arguments
Returns:
list[float]: Reward values for each completion
"""
extracted_responses = [r1_extract_xml_answer(r) for r in completions]
return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
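
For a quick sense of how these first rewards behave, here is an illustrative check; it is not part of this commit, and the import path is only an assumption about where this file ends up:

# Illustrative sanity check only; assumes grpo_trainer.py is importable under this name.
from grpo_trainer import (
    r1_extract_xml_answer,
    r1_accuracy_reward_func,
    r1_int_reward_func,
)

prompts = ["What is 2 + 2?"]
completions = ["<think>\n2 + 2 = 4\n</think>\n<answer>\n4\n</answer>\n"]
answer = ["4"]

print(r1_extract_xml_answer(completions[0]))                  # "4"
print(r1_accuracy_reward_func(prompts, completions, answer))  # [2.0]
print(r1_int_reward_func(completions))                        # [0.5]
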
def r1_strict_format_reward_func(completions, **kwargs) -> list[float]:
"""Rewards completions with strict XML format.
Args:
completions: List of completion strings
**kwargs: Additional arguments
Returns:
list[float]: Reward values for each completion
"""
pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>\n$"
matches = [re.match(pattern, r) for r in completions]
return [0.5 if match else 0.0 for match in matches]
def r1_soft_format_reward_func(completions, **kwargs) -> list[float]:
"""Rewards completions with flexible XML format.
Args:
completions: List of completion strings
**kwargs: Additional arguments
Returns:
list[float]: Reward values for each completion
"""
pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
matches = [re.match(pattern, r) for r in completions]
return [0.5 if match else 0.0 for match in matches]
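
The strict and soft patterns accept different layouts; a standalone comparison (illustrative, not in the diff):

import re

strict = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>\n$"
soft = r"<think>.*?</think>\s*<answer>.*?</answer>"

multiline = "<think>\nstep by step\n</think>\n<answer>\n42\n</answer>\n"
one_line = "<think>step by step</think> <answer>42</answer>"

bool(re.match(strict, multiline))  # True  -> strict reward 0.5
bool(re.match(soft, multiline))    # False -> '.' does not cross newlines without re.DOTALL
bool(re.match(strict, one_line))   # False -> strict requires the exact newline layout
bool(re.match(soft, one_line))     # True  -> soft reward 0.5
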
def r1_count_xml(text: str) -> float:
"""Calculates score based on XML formatting.
Args:
text: Input text string
Returns:
float: Score based on XML tag presence and formatting
"""
count = 0.0
if text.count("<think>\n") == 1:
count += 0.125
if text.count("\n</think>\n") == 1:
count += 0.125
if text.count("\n<answer>\n") == 1:
count += 0.125
count -= len(text.split("\n</answer>\n")[-1])*0.001
if text.count("\n</answer>") == 1:
count += 0.125
count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
return count
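
r1_count_xml grants 0.125 per well-placed tag and shaves 0.001 per character of trailing text after the closing answer tag; for example (illustrative only, assuming the function above is in scope):

clean = "<think>\nreasoning\n</think>\n<answer>\n4\n</answer>\n"
trailing = "<think>\nreasoning\n</think>\n<answer>\n4\n</answer>\nextra"

r1_count_xml(clean)     # 0.5   (4 tags x 0.125, no trailing text)
r1_count_xml(trailing)  # ~0.49 (both trailing-text penalties fire on "extra")
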
def grpo_loss(
    model,
    tokenizer,
    prompts,
-    reward_funcs=None,
    reward_funcs=[
        r1_accuracy_reward_func,
        r1_int_reward_func,
        r1_strict_format_reward_func,
        r1_soft_format_reward_func,
        r1_count_xml
    ],
    beta=0.1,
    group_size=4,
-    epslion=1e-4,
-    ref_model = None
    epsilon=1e-4,
    ref_model=None
):
"""
Calculates the GRPO loss with support for multiple reward functions.
Args:
model: The model to optimize
tokenizer: Tokenizer for processing inputs
prompts: List of input prompts
reward_funcs: List of reward functions to use
beta: KL penalty coefficient
group_size: Number of completions per prompt
epsilon: Small constant for numerical stability
ref_model: Optional reference model for KL divergence
Returns:
tuple: (loss, total_sequence_length, metrics_dict)
"""
batch_size = len(prompts)
# Generate multiple completions for each prompt
all_completions = []
@@ -83,7 +169,7 @@ def grpo_loss(
prompt_completions.append(completion)
all_completions.extend(prompt_completions)
-    # Tokenize all prompts + completions (needed for model processing)
    # Tokenize all prompts + completions
tokenized_inputs = tokenizer(
[p + c for p, c in zip(prompts * group_size, all_completions)],
return_tensors="np",
@@ -102,7 +188,7 @@
# Calculate log probabilities
log_probs = mx.log_softmax(logits[:, :-1, :], axis=-1)
-    # Prepare targets (shift input_ids left by one position)
    # Prepare targets
targets = inputs[:, 1:]
# Gather actual token probabilities
@@ -125,29 +211,33 @@
axis=-1
).squeeze(-1)
-    # Compute the KL divergence between the model and the reference model
    # Compute KL divergence
kl_div = (mx.exp(ref_token_log_probs - token_log_probs) - (ref_token_log_probs - token_log_probs) - 1)
-    # Calculate rewards
-    if reward_funcs:
-        rewards = mx.array([sum(rf(all_completions) for rf in reward_funcs)])
-    else:
-        rewards = compute_default_rewards(all_completions, batch_size, group_size)

    # Calculate combined rewards from all reward functions
    rewards = mx.zeros((len(all_completions),))
    for reward_func in reward_funcs:
        func_rewards = mx.array(reward_func(all_completions))
        rewards += func_rewards

    # Normalize rewards if using multiple reward functions
    if len(reward_funcs) > 1:
        rewards /= len(reward_funcs)
# Compute grouped-wise rewards
grouped_rewards = rewards.reshape(batch_size, group_size)
mean_grouped_rewards = mx.mean(grouped_rewards, axis=1)
std_grouped_rewards = mx.std(grouped_rewards, axis=1)
-    # Normalize the rewards to compute the advantages
    # Normalize rewards to compute advantages
    mean_grouped_rewards = mx.repeat(mean_grouped_rewards.reshape(-1, 1), group_size, axis=1).reshape(-1)
    std_grouped_rewards = mx.repeat(std_grouped_rewards.reshape(-1, 1), group_size, axis=1).reshape(-1)
-    advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + epslion)
    advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + epsilon)
# Create length mask for the shifted sequence
length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)
-    # Calculate policy gradient loss, mx.stop_gradient allows for preserving gradients from token_log_probs
    # Calculate policy gradient loss
per_token_loss = mx.exp(token_log_probs - mx.stop_gradient(token_log_probs)) * advantages.reshape(-1, 1)
per_token_loss = -(per_token_loss - beta * kl_div)
@@ -156,14 +246,24 @@ def grpo_loss(
sequence_lengths = length_mask.sum(axis=1)
loss = (sequence_sums / sequence_lengths).mean()
-    # Calculate mean KL divergence (normalized per sequence)
    # Calculate mean KL divergence
mean_kl = ((kl_div * length_mask).sum(axis=1) / length_mask.sum(axis=1)).mean()
# Collect metrics for each reward function separately
reward_metrics = {}
for i, reward_func in enumerate(reward_funcs):
func_rewards = mx.array(reward_func(all_completions))
func_grouped_rewards = func_rewards.reshape(batch_size, group_size)
reward_metrics[f'reward_func_{i}_mean'] = mx.mean(func_rewards)
reward_metrics[f'reward_func_{i}_std'] = mx.std(func_rewards)
    metrics = {
-        'rewards': rewards,
-        'rewards_std': mx.std(rewards),
-        'grouped_rewards': grouped_rewards,
        'total_rewards_mean': mx.mean(rewards),
        'total_rewards_std': mx.std(rewards),
        'grouped_rewards_mean': mx.mean(grouped_rewards),
        'grouped_rewards_std': mx.std(grouped_rewards),
-        'kl': mean_kl
        'kl': mean_kl,
        **reward_metrics
    }
return loss, sequence_lengths.sum(), metrics
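
The two numerical steps worth internalizing are the group-relative advantage and the exp(d) - d - 1 KL penalty. A minimal standalone sketch with made-up rewards and log-probs, mirroring the code above (not part of the commit):

import mlx.core as mx

group_size, epsilon = 4, 1e-4

# One prompt, four sampled completions, made-up rewards
rewards = mx.array([2.5, 0.5, 3.0, 0.5])

# Group-relative advantages, as computed in grpo_loss
grouped_rewards = rewards.reshape(-1, group_size)
mean_grouped = mx.repeat(mx.mean(grouped_rewards, axis=1).reshape(-1, 1), group_size, axis=1).reshape(-1)
std_grouped = mx.repeat(mx.std(grouped_rewards, axis=1).reshape(-1, 1), group_size, axis=1).reshape(-1)
advantages = (rewards - mean_grouped) / (std_grouped + epsilon)

# KL penalty per token: exp(d) - d - 1 with d = ref_log_prob - policy_log_prob (always >= 0)
token_log_probs = mx.array([-1.2, -0.7])      # made-up policy log-probs
ref_token_log_probs = mx.array([-1.0, -0.9])  # made-up reference log-probs
d = ref_token_log_probs - token_log_probs
kl_div = mx.exp(d) - d - 1
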