From d034ca369eed041cc84ef16b75de01cb517cab05 Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Mon, 3 Feb 2025 08:26:42 +0100
Subject: [PATCH] adding function for R1

---
 llms/mlx_lm/tuner/grpo_trainer.py | 182 +++++++++++++++++++++++-------
 1 file changed, 141 insertions(+), 41 deletions(-)

diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index c3b3007e..ac735264 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -1,11 +1,9 @@
 # Copyright © 2024 Apple Inc.
 
-import glob
-import shutil
 import time
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Union
+import re
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -38,40 +36,128 @@ class GRPOTrainingArgs(TrainingArgs):
     )
 
 
-def compute_default_rewards(sequences, batch_size, group_size):
-    """
+def r1_extract_xml_answer(text: str) -> str:
+    """Extracts the answer from an XML formatted text string."""
+    try:
+        answer = text.split("<answer>")[-1]
+        answer = answer.split("</answer>")[0]
+        return answer.strip()
+    except:
+        print("[extract_xml_answer] Failed to extract answer from: ", text)
+        return ""
+
+def r1_accuracy_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
+    """Calculates reward based on accuracy of extracted answers.
+
     Args:
-        sequences: List of word sequences
-        batch_size: Number of original prompts
-        group_size: Number of generations per prompt
-    """
-    rewards = mx.zeros((len(sequences),))
-
-    for i, sequence in enumerate(sequences):
-        # Convert sequence to list if it isn't already
-        if not isinstance(sequence, list):
-            sequence = sequence.split()
-
-        # Get the target (reversed) sequence
-        target = sequence[::-1]
+        prompts: List of input prompts
+        completions: List of completion strings
+        answer: Expected answer or list of answers
+        **kwargs: Additional arguments
 
-        # Calculate accuracy of reversal
-        correct_positions = sum(1 for a, b in zip(sequence, target) if a == b)
-        rewards[i] = correct_positions / len(sequence)
+    Returns:
+        list[float]: Reward values for each completion
+    """
+    extracted_responses = [r1_extract_xml_answer(r) for r in completions]
+    q = prompts[0] if isinstance(prompts[0], str) else prompts[0][-1]['content']
+    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
+
+def r1_int_reward_func(completions, **kwargs) -> list[float]:
+    """Rewards numerical responses.
 
-    return rewards
+    Args:
+        completions: List of completion strings
+        **kwargs: Additional arguments
+
+    Returns:
+        list[float]: Reward values for each completion
+    """
+    extracted_responses = [r1_extract_xml_answer(r) for r in completions]
+    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
+
+def r1_strict_format_reward_func(completions, **kwargs) -> list[float]:
+    """Rewards completions with strict XML format.
+
+    Args:
+        completions: List of completion strings
+        **kwargs: Additional arguments
+
+    Returns:
+        list[float]: Reward values for each completion
+    """
+    pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>\n$"
+    matches = [re.match(pattern, r) for r in completions]
+    return [0.5 if match else 0.0 for match in matches]
+
+def r1_soft_format_reward_func(completions, **kwargs) -> list[float]:
+    """Rewards completions with flexible XML format.
+
+    Args:
+        completions: List of completion strings
+        **kwargs: Additional arguments
+
+    Returns:
+        list[float]: Reward values for each completion
+    """
+    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
+    matches = [re.match(pattern, r) for r in completions]
+    return [0.5 if match else 0.0 for match in matches]
+
+def r1_count_xml(text: str) -> float:
+    """Calculates score based on XML formatting.
+
+    Args:
+        text: Input text string
+
+    Returns:
+        float: Score based on XML tag presence and formatting
+    """
+    count = 0.0
+    if text.count("<think>\n") == 1:
+        count += 0.125
+    if text.count("\n</think>\n") == 1:
+        count += 0.125
+    if text.count("\n<answer>\n") == 1:
+        count += 0.125
+        count -= len(text.split("\n</answer>\n")[-1])*0.001
+    if text.count("\n</answer>") == 1:
+        count += 0.125
+        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
+    return count
 
 
 def grpo_loss(
     model,
     tokenizer,
     prompts,
-    reward_funcs=None,
+    reward_funcs=[
+        r1_accuracy_reward_func,
+        r1_int_reward_func,
+        r1_strict_format_reward_func,
+        r1_soft_format_reward_func,
+        r1_count_xml
+    ],
     beta=0.1,
     group_size=4,
-    epslion=1e-4,
-    ref_model = None
+    epsilon=1e-4,
+    ref_model=None
 ):
+    """
+    Calculates the GRPO loss with support for multiple reward functions.
+
+    Args:
+        model: The model to optimize
+        tokenizer: Tokenizer for processing inputs
+        prompts: List of input prompts
+        reward_funcs: List of reward functions to use
+        beta: KL penalty coefficient
+        group_size: Number of completions per prompt
+        epsilon: Small constant for numerical stability
+        ref_model: Optional reference model for KL divergence
+
+    Returns:
+        tuple: (loss, total_sequence_length, metrics_dict)
+    """
     batch_size = len(prompts)
     # Generate multiple completions for each prompt
     all_completions = []
@@ -83,7 +169,7 @@ def grpo_loss(
             prompt_completions.append(completion)
         all_completions.extend(prompt_completions)
 
-    # Tokenize all prompts + completions (needed for model processing)
+    # Tokenize all prompts + completions
     tokenized_inputs = tokenizer(
         [p + c for p, c in zip(prompts * group_size, all_completions)],
         return_tensors="np",
@@ -102,7 +188,7 @@ def grpo_loss(
     # Calculate log probabilities
     log_probs = mx.log_softmax(logits[:, :-1, :], axis=-1)
 
-    # Prepare targets (shift input_ids left by one position)
+    # Prepare targets
     targets = inputs[:, 1:]
 
     # Gather actual token probabilities
@@ -125,29 +211,33 @@ def grpo_loss(
         axis=-1
     ).squeeze(-1)
 
-    # Compute the KL divergence between the model and the reference model
+    # Compute KL divergence
     kl_div = (mx.exp(ref_token_log_probs - token_log_probs) - (ref_token_log_probs - token_log_probs) - 1)
 
-    # Calculate rewards
-    if reward_funcs:
-        rewards = mx.array([sum(rf(all_completions) for rf in reward_funcs)])
-    else:
-        rewards = compute_default_rewards(all_completions, batch_size, group_size)
+    # Calculate combined rewards from all reward functions
+    rewards = mx.zeros((len(all_completions),))
+    for reward_func in reward_funcs:
+        func_rewards = mx.array(reward_func(all_completions))
+        rewards += func_rewards
+
+    # Normalize rewards if using multiple reward functions
+    if len(reward_funcs) > 1:
+        rewards /= len(reward_funcs)
 
     # Compute grouped-wise rewards
    grouped_rewards = rewards.reshape(batch_size, group_size)
     mean_grouped_rewards = mx.mean(grouped_rewards, axis=1)
     std_grouped_rewards = mx.std(grouped_rewards, axis=1)
 
-    # Normalize the rewards to compute the advantages
+    # Normalize rewards to compute advantages
     mean_grouped_rewards = mx.repeat(mean_grouped_rewards.reshape(-1, 1), group_size, axis=1).reshape(-1)
     std_grouped_rewards = mx.repeat(std_grouped_rewards.reshape(-1, 1), group_size, axis=1).reshape(-1)
-    advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + epslion)
+    advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + epsilon)
 
     # Create length mask for the shifted sequence
     length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)
 
-    # Calculate policy gradient loss, mx.stop_gradient allows for preserving gradients from token_log_probs
+    # Calculate policy gradient loss
     per_token_loss = mx.exp(token_log_probs - mx.stop_gradient(token_log_probs)) * advantages.reshape(-1, 1)
     per_token_loss = -(per_token_loss - beta * kl_div)
 
@@ -156,14 +246,24 @@ def grpo_loss(
     sequence_lengths = length_mask.sum(axis=1)
     loss = (sequence_sums / sequence_lengths).mean()
 
-    # Calculate mean KL divergence (normalized per sequence)
+    # Calculate mean KL divergence
     mean_kl = ((kl_div * length_mask).sum(axis=1) / length_mask.sum(axis=1)).mean()
+
+    # Collect metrics for each reward function separately
+    reward_metrics = {}
+    for i, reward_func in enumerate(reward_funcs):
+        func_rewards = mx.array(reward_func(all_completions))
+        func_grouped_rewards = func_rewards.reshape(batch_size, group_size)
+        reward_metrics[f'reward_func_{i}_mean'] = mx.mean(func_rewards)
+        reward_metrics[f'reward_func_{i}_std'] = mx.std(func_rewards)
 
     metrics = {
-        'rewards': rewards,
-        'rewards_std': mx.std(rewards),
-        'grouped_rewards': grouped_rewards,
+        'total_rewards_mean': mx.mean(rewards),
+        'total_rewards_std': mx.std(rewards),
+        'grouped_rewards_mean': mx.mean(grouped_rewards),
         'grouped_rewards_std': mx.std(grouped_rewards),
-        'kl': mean_kl
+        'kl': mean_kl,
+        **reward_metrics
     }
 
     return loss, sequence_lengths.sum(), metrics
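
Note for reviewers: the new reward functions can be sanity-checked in isolation. The sketch below is illustrative only; it assumes the patched module is importable as mlx_lm.tuner.grpo_trainer, and the sample completions plus the scores in the comments are hand-worked from the function bodies above rather than output from a test run.

    # Sanity-check sketch for the new R1 reward functions (illustrative only;
    # assumes the patched mlx_lm package is importable).
    from mlx_lm.tuner.grpo_trainer import (
        r1_accuracy_reward_func,
        r1_int_reward_func,
        r1_strict_format_reward_func,
        r1_soft_format_reward_func,
        r1_count_xml,
    )

    # One completion per formatting style: strict multi-line layout, loose
    # single-line layout, and no tags at all.
    strict = "<think>\n2 + 2 = 4\n</think>\n<answer>\n4\n</answer>\n"
    loose = "<think>2 + 2 = 4</think> <answer>4</answer>"
    untagged = "The answer is 4."

    completions = [strict, loose, untagged]
    prompts = ["What is 2 + 2?"] * len(completions)
    answers = ["4"] * len(completions)

    print(r1_accuracy_reward_func(prompts, completions, answers))  # [2.0, 2.0, 0.0]
    print(r1_int_reward_func(completions))                         # [0.5, 0.5, 0.0]
    print(r1_strict_format_reward_func(completions))               # [0.5, 0.0, 0.0]
    # re.match with a dot that does not cross newlines only rewards the
    # single-line layout here: [0.0, 0.5, 0.0]
    print(r1_soft_format_reward_func(completions))
    print([r1_count_xml(c) for c in completions])                  # [0.5, 0.0, 0.0]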
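
The advantage step that the patch leaves in place (apart from the epslion -> epsilon rename) standardizes each completion's combined reward within its own group of group_size samples, so a completion is only rewarded relative to the other generations for the same prompt. A toy, self-contained illustration of that bookkeeping, assuming mlx is installed and using made-up reward values:

    # Toy illustration of the grouped advantage computation used in grpo_loss
    # (two prompts, group_size=4, invented reward values).
    import mlx.core as mx

    batch_size, group_size, epsilon = 2, 4, 1e-4
    rewards = mx.array([1.0, 0.5, 0.0, 0.5,   # group for prompt 0
                        2.0, 2.0, 0.0, 0.0])  # group for prompt 1

    grouped = rewards.reshape(batch_size, group_size)
    mean_g = mx.mean(grouped, axis=1)
    std_g = mx.std(grouped, axis=1)

    # Broadcast the per-group statistics back to one entry per completion,
    # exactly as the patch does with mx.repeat(...).reshape(-1).
    mean_g = mx.repeat(mean_g.reshape(-1, 1), group_size, axis=1).reshape(-1)
    std_g = mx.repeat(std_g.reshape(-1, 1), group_size, axis=1).reshape(-1)

    advantages = (rewards - mean_g) / (std_g + epsilon)
    print(advantages)  # positive above the group mean, negative below it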