adding function for R1

Goekdeniz-Guelmez 2025-02-03 08:26:42 +01:00
parent 243c9621d9
commit d034ca369e


@@ -1,11 +1,9 @@
 # Copyright © 2024 Apple Inc.
 
-import glob
-import shutil
 import time
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Union
+import re
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -38,40 +36,128 @@ class GRPOTrainingArgs(TrainingArgs):
     )
 
 
-def compute_default_rewards(sequences, batch_size, group_size):
-    """
-    Args:
-        sequences: List of word sequences
-        batch_size: Number of original prompts
-        group_size: Number of generations per prompt
-    """
-    rewards = mx.zeros((len(sequences),))
-
-    for i, sequence in enumerate(sequences):
-        # Convert sequence to list if it isn't already
-        if not isinstance(sequence, list):
-            sequence = sequence.split()
-
-        # Get the target (reversed) sequence
-        target = sequence[::-1]
-
-        # Calculate accuracy of reversal
-        correct_positions = sum(1 for a, b in zip(sequence, target) if a == b)
-        rewards[i] = correct_positions / len(sequence)
-
-    return rewards
+def r1_extract_xml_answer(text: str) -> str:
+    """Extracts the answer from an XML formatted text string."""
+    try:
+        answer = text.split("<answer>")[-1]
+        answer = answer.split("</answer>")[0]
+        return answer.strip()
+    except:
+        print("[extract_xml_answer] Failed to extract answer from: ", text)
+        return ""
+
+
+def r1_accuracy_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
+    """Calculates reward based on accuracy of extracted answers.
+
+    Args:
+        prompts: List of input prompts
+        completions: List of completion strings
+        answer: Expected answer or list of answers
+        **kwargs: Additional arguments
+
+    Returns:
+        list[float]: Reward values for each completion
+    """
+    extracted_responses = [r1_extract_xml_answer(r) for r in completions]
+    q = prompts[0] if isinstance(prompts[0], str) else prompts[0][-1]['content']
+    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
+
+
+def r1_int_reward_func(completions, **kwargs) -> list[float]:
+    """Rewards numerical responses.
+
+    Args:
+        completions: List of completion strings
+        **kwargs: Additional arguments
+
+    Returns:
+        list[float]: Reward values for each completion
+    """
+    extracted_responses = [r1_extract_xml_answer(r) for r in completions]
+    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
+
+
+def r1_strict_format_reward_func(completions, **kwargs) -> list[float]:
+    """Rewards completions with strict XML format.
+
+    Args:
+        completions: List of completion strings
+        **kwargs: Additional arguments
+
+    Returns:
+        list[float]: Reward values for each completion
+    """
+    pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>\n$"
+    matches = [re.match(pattern, r) for r in completions]
+    return [0.5 if match else 0.0 for match in matches]
+
+
+def r1_soft_format_reward_func(completions, **kwargs) -> list[float]:
+    """Rewards completions with flexible XML format.
+
+    Args:
+        completions: List of completion strings
+        **kwargs: Additional arguments
+
+    Returns:
+        list[float]: Reward values for each completion
+    """
+    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
+    matches = [re.match(pattern, r) for r in completions]
+    return [0.5 if match else 0.0 for match in matches]
+
+
+def r1_count_xml(text: str) -> float:
+    """Calculates score based on XML formatting.
+
+    Args:
+        text: Input text string
+
+    Returns:
+        float: Score based on XML tag presence and formatting
+    """
+    count = 0.0
+    if text.count("<think>\n") == 1:
+        count += 0.125
+    if text.count("\n</think>\n") == 1:
+        count += 0.125
+    if text.count("\n<answer>\n") == 1:
+        count += 0.125
+        count -= len(text.split("\n</answer>\n")[-1])*0.001
+    if text.count("\n</answer>") == 1:
+        count += 0.125
+        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
+    return count
 
 
 def grpo_loss(
     model,
     tokenizer,
     prompts,
-    reward_funcs=None,
+    reward_funcs=[
+        r1_accuracy_reward_func,
+        r1_int_reward_func,
+        r1_strict_format_reward_func,
+        r1_soft_format_reward_func,
+        r1_count_xml
+    ],
     beta=0.1,
     group_size=4,
-    epslion=1e-4,
-    ref_model = None
+    epsilon=1e-4,
+    ref_model=None
 ):
+    """
+    Calculates the GRPO loss with support for multiple reward functions.
+
+    Args:
+        model: The model to optimize
+        tokenizer: Tokenizer for processing inputs
+        prompts: List of input prompts
+        reward_funcs: List of reward functions to use
+        beta: KL penalty coefficient
+        group_size: Number of completions per prompt
+        epsilon: Small constant for numerical stability
+        ref_model: Optional reference model for KL divergence
+
+    Returns:
+        tuple: (loss, total_sequence_length, metrics_dict)
+    """
     batch_size = len(prompts)
 
     # Generate multiple completions for each prompt
     all_completions = []
@@ -83,7 +169,7 @@ def grpo_loss(
             prompt_completions.append(completion)
         all_completions.extend(prompt_completions)
 
-    # Tokenize all prompts + completions (needed for model processing)
+    # Tokenize all prompts + completions
     tokenized_inputs = tokenizer(
         [p + c for p, c in zip(prompts * group_size, all_completions)],
         return_tensors="np",
@@ -102,7 +188,7 @@ def grpo_loss(
     # Calculate log probabilities
     log_probs = mx.log_softmax(logits[:, :-1, :], axis=-1)
 
-    # Prepare targets (shift input_ids left by one position)
+    # Prepare targets
    targets = inputs[:, 1:]
 
     # Gather actual token probabilities
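As context for the gather step in the next hunk, here is a small standalone sketch of the shift-by-one alignment set up above: logits at position t score token t+1, so the targets are the inputs shifted left by one. The token ids and vocabulary size are made up, and mx.log(mx.softmax(...)) stands in for the trainer's log-softmax call.

import mlx.core as mx

# Toy shapes: batch 1, sequence length 4, vocabulary size 11 (all made up).
inputs = mx.array([[5, 7, 9, 2]])
logits = mx.random.normal((1, 4, 11))

# Logits at position t predict token t+1, so drop the last position...
log_probs = mx.log(mx.softmax(logits[:, :-1, :], axis=-1))
# ...and shift the inputs left by one to get the targets.
targets = inputs[:, 1:]  # [[7, 9, 2]]

# Gather the log-probability assigned to each actual next token.
token_log_probs = mx.take_along_axis(
    log_probs, mx.expand_dims(targets, -1), axis=-1
).squeeze(-1)
print(token_log_probs.shape)  # (1, 3)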
@@ -125,29 +211,33 @@ def grpo_loss(
         axis=-1
     ).squeeze(-1)
 
-    # Compute the KL divergence between the model and the reference model
+    # Compute KL divergence
     kl_div = (mx.exp(ref_token_log_probs - token_log_probs) - (ref_token_log_probs - token_log_probs) - 1)
 
-    # Calculate rewards
-    if reward_funcs:
-        rewards = mx.array([sum(rf(all_completions) for rf in reward_funcs)])
-    else:
-        rewards = compute_default_rewards(all_completions, batch_size, group_size)
+    # Calculate combined rewards from all reward functions
+    rewards = mx.zeros((len(all_completions),))
+    for reward_func in reward_funcs:
+        func_rewards = mx.array(reward_func(all_completions))
+        rewards += func_rewards
+
+    # Normalize rewards if using multiple reward functions
+    if len(reward_funcs) > 1:
+        rewards /= len(reward_funcs)
 
     # Compute grouped-wise rewards
     grouped_rewards = rewards.reshape(batch_size, group_size)
     mean_grouped_rewards = mx.mean(grouped_rewards, axis=1)
     std_grouped_rewards = mx.std(grouped_rewards, axis=1)
 
-    # Normalize the rewards to compute the advantages
+    # Normalize rewards to compute advantages
     mean_grouped_rewards = mx.repeat(mean_grouped_rewards.reshape(-1, 1), group_size, axis=1).reshape(-1)
     std_grouped_rewards = mx.repeat(std_grouped_rewards.reshape(-1, 1), group_size, axis=1).reshape(-1)
-    advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + epslion)
+    advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + epsilon)
 
     # Create length mask for the shifted sequence
     length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)
 
-    # Calculate policy gradient loss, mx.stop_gradient allows for preserving gradients from token_log_probs
+    # Calculate policy gradient loss
     per_token_loss = mx.exp(token_log_probs - mx.stop_gradient(token_log_probs)) * advantages.reshape(-1, 1)
     per_token_loss = -(per_token_loss - beta * kl_div)
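The advantage computation in this hunk is the standard GRPO group normalization: each completion's combined reward is centered and scaled by the mean and standard deviation of its own group of group_size generations. A small numeric sketch with made-up reward values:

import mlx.core as mx

batch_size, group_size, epsilon = 2, 4, 1e-4

# Combined rewards for 2 prompts x 4 completions each (values are made up).
rewards = mx.array([3.5, 0.5, 2.5, 0.5, 0.0, 0.5, 3.0, 0.5])

grouped = rewards.reshape(batch_size, group_size)
mean_g = mx.repeat(mx.mean(grouped, axis=1).reshape(-1, 1), group_size, axis=1).reshape(-1)
std_g = mx.repeat(mx.std(grouped, axis=1).reshape(-1, 1), group_size, axis=1).reshape(-1)

# Completions that beat their own group's mean get positive advantages.
advantages = (rewards - mean_g) / (std_g + epsilon)
print(advantages)

The epsilon term only guards against division by zero when every completion in a group receives the same reward.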
@@ -156,14 +246,24 @@ def grpo_loss(
     sequence_lengths = length_mask.sum(axis=1)
     loss = (sequence_sums / sequence_lengths).mean()
 
-    # Calculate mean KL divergence (normalized per sequence)
+    # Calculate mean KL divergence
     mean_kl = ((kl_div * length_mask).sum(axis=1) / length_mask.sum(axis=1)).mean()
 
+    # Collect metrics for each reward function separately
+    reward_metrics = {}
+    for i, reward_func in enumerate(reward_funcs):
+        func_rewards = mx.array(reward_func(all_completions))
+        func_grouped_rewards = func_rewards.reshape(batch_size, group_size)
+        reward_metrics[f'reward_func_{i}_mean'] = mx.mean(func_rewards)
+        reward_metrics[f'reward_func_{i}_std'] = mx.std(func_rewards)
+
     metrics = {
-        'rewards': rewards,
-        'rewards_std': mx.std(rewards),
-        'grouped_rewards': grouped_rewards,
+        'total_rewards_mean': mx.mean(rewards),
+        'total_rewards_std': mx.std(rewards),
+        'grouped_rewards_mean': mx.mean(grouped_rewards),
         'grouped_rewards_std': mx.std(grouped_rewards),
-        'kl': mean_kl
+        'kl': mean_kl,
+        **reward_metrics
     }
 
     return loss, sequence_lengths.sum(), metrics