removing comments + adding temperature + reward weighting

Goekdeniz-Guelmez 2025-02-15 15:29:22 +01:00
parent baeb9f117f
commit 5ec4790656
2 changed files with 64 additions and 52 deletions

View File

@@ -1,21 +1,21 @@
# Copyright © 2024 Apple Inc.
import argparse
import math
import os
import re
import types
from pathlib import Path
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
import yaml
from .tokenizer_utils import TokenizerWrapper
from .tuner.datasets import load_dataset
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
from .tuner.grpo_trainer import GRPOTrainingArgs, evaluate_grpo, train_grpo
from .tuner.utils import (
build_schedule,
linear_to_lora_layers,
@@ -73,6 +73,8 @@ CONFIG_DEFAULTS = {
"max_completion_length": 512,
"use_chat_template": False,
"use_prompt": False,
"temperature": 1.0,
"reward_weights": None,
}
@@ -224,6 +226,18 @@ def build_parser():
help="Rather to use the prompt from the R1 paper.",
default=None,
)
parser.add_argument(
"--temperature",
type=float,
help="Temperature for sampling. The higher the temperature, the more random the completions.",
default=1.0,
)
parser.add_argument(
"--reward-weights",
type=str,
help="Weights for each reward function. Must match the number of reward functions and be in this format [0.1, 0.2, 0.3, 0.4, 0.5]. If not given, all rewards are weighted equally with weight `1.0`.",
default=None,
)
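# Illustrative only (not part of this commit): the two new flags can be passed as, e.g.,
#   --temperature 0.8 --reward-weights "[1.0, 1.0, 0.5, 0.5, 0.25]"
# The bracketed string is parsed into a list of floats in train_model_grpo below.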
return parser
def train_model_grpo(model, tokenizer, args, opt, train_set, valid_set, adapter_file, training_callback):
@@ -241,7 +255,9 @@ def train_model_grpo(model, tokenizer, args, opt, train_set, valid_set, adapter_
beta=args.beta,
group_size=args.group_size,
epsilon=args.epsilon,
reference_model_path=args.reference_model_path,
temperature=args.temperature,
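# Convert the "[w1, w2, ...]" string from --reward-weights into a list of floats;
# None keeps the default of weighting every reward function equally.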
reward_weights=[float(x) for x in args.reward_weights.strip('[]').split(',')] if args.reward_weights else None
)
if args.reference_model_path:

View File

@@ -35,13 +35,24 @@ class GRPOTrainingArgs(TrainingArgs):
"help": "Path to reference model weights. If None, uses the same model."
}
)
temperature: float = field(
default=1.0,
metadata={
"help": "Temperature for sampling. The higher the temperature, the more random the completions."
}
)
reward_weights: Optional[List[float]] = field(
default=None,
metadata={
"help": "Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are weighted equally with weight `1.0`."
}
)
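# Note: weights are applied positionally, so reward_weights must be given in the same
# order as the reward functions passed to grpo_loss, which checks that the lengths match.
# Each reward function maps (prompts, completions, answers) to one float per completion,
# as captured by the RewardFunctions alias below.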
RewardFunctions = Callable[[List[str], List[str], List[str]], List[float]]
def r1_extract_xml_answer(text: str) -> str:
"""Extracts the answer from an XML formatted text string."""
try:
answer = text.split("<answer>")[-1]
answer = answer.split("</answer>")[0]
@@ -52,14 +63,12 @@ def r1_extract_xml_answer(text: str) -> str:
def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
"""Ensures we always return a list of floats."""
if not completions:
return [0.0] * len(prompts)
extracted_responses = [r1_extract_xml_answer(r) for r in completions]
return [0.5 if r and r.isdigit() else 0.0 for r in extracted_responses]
def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
"""Ensures we always return a list of floats."""
if not completions or not answer:
return [0.0] * len(prompts)
extracted_responses = [r1_extract_xml_answer(r) for r in completions]
@@ -67,7 +76,6 @@ def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kw
def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
"""Ensures we always return a list of floats."""
if not completions:
return [0.0] * len(prompts)
pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
@@ -76,7 +84,6 @@ def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, *
def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
"""Ensures we always return a list of floats."""
if not completions:
return [0.0] * len(prompts)
pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>\n$"
@@ -85,16 +92,13 @@ def r1_strict_format_reward_func(prompts: list, completions: list, answer: list,
def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
"""Ensures we always return a list of floats."""
if not completions:
return [0.0] * len(prompts)
scores = []
for text in completions:
if not text:
scores.append(0.0)
continue
count = 0.0
if text.count("<think>\n") == 1:
count += 0.125
@@ -104,13 +108,9 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> li
count += 0.125
if text.count("\n</answer>\n") == 1:
count += 0.125
# Penalize extra text after </answer>
end_text = text.split("\n</answer>\n")[-1]
count -= len(end_text) * 0.001 if len(end_text) > 0 else 0
scores.append(max(0.0, count)) # Ensure non-negative score
return scores
@@ -119,22 +119,18 @@ def generate_grpo(model: nn.Module, prompt, max_tokens, tokenizer, temperature):
prompt = prompt[None, :]
if prompt.shape[1] == 0:
return None
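# Pre-encode the closing "</answer>" tag so generation can stop as soon as it appears,
# in addition to stopping on the EOS token (both checks are in the loop below).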
end_sequence = tokenizer.encode("</answer>")
end_sequence_length = len(end_sequence)
initial_length = prompt.shape[1]
output = mx.zeros((initial_length + max_tokens,), dtype=mx.int32)
output[:initial_length] = prompt[0]
current_length = initial_length
try:
def sample(logits):
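# Temperature scaling: for temperature > 0 the logits are divided by the temperature, so
# higher values flatten the distribution (more random completions) and lower values
# sharpen it; the next token is then drawn from the resulting categorical distribution.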
if temperature > 0:
logits /= temperature
logprobs = logits - mx.logsumexp(logits, keepdims=True)
return mx.random.categorical(logprobs[None, :]).astype(mx.int32)[0]
for _ in range(max_tokens):
current_input = output[:current_length][None, :]
logits = model(current_input)
@@ -143,18 +139,14 @@ def generate_grpo(model: nn.Module, prompt, max_tokens, tokenizer, temperature):
token_value = next_token.item()
output[current_length] = token_value
current_length += 1
if token_value == tokenizer.eos_token_id:
break
if current_length >= end_sequence_length:
last_tokens = output[current_length - end_sequence_length:current_length].tolist()
if last_tokens == end_sequence:
break
if current_length > initial_length:
return output[:current_length]
except Exception as e:
print(f"Generation error: {str(e)}")
return None
@@ -192,9 +184,10 @@ def grpo_loss(
group_size=4,
epsilon=1e-4,
max_tokens=64,
temperature=1.0,
reward_weights=None
):
prompt_tokens, _, prompt_text, answer_text = batch
batch_size = len(prompt_tokens)
all_completions = []
@@ -273,18 +266,34 @@ def grpo_loss(
token_log_probs = mx.stack(padded_log_probs)
ref_token_log_probs = mx.stack(padded_ref_log_probs)
# Create array to store rewards from each function
all_func_rewards = []
# Collect rewards from each function separately
for reward_func in reward_funcs:
func_rewards = mx.array(reward_func(
prompts=expanded_prompts,
completions=all_completion_texts,
answer=expanded_answers
))
all_func_rewards.append(func_rewards)
# Stack rewards to shape (num_samples, num_funcs)
rewards = mx.stack(all_func_rewards, axis=1)
print(f"Rewards: {rewards}")
# Apply weights and sum
if reward_weights is not None:
if len(reward_weights) != len(reward_funcs):
raise ValueError(
f"Number of reward weights ({len(reward_weights)}) must match number of reward "
f"functions ({len(reward_funcs)})"
)
reward_weights = mx.array(reward_weights, dtype=mx.float32)
else:
reward_weights = mx.ones(len(reward_funcs), dtype=mx.float32)
rewards = (rewards * mx.expand_dims(reward_weights, 0)).sum(axis=1)
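# Illustrative example: with two reward functions and reward_weights=[1.0, 0.5], a
# completion scoring (0.5, 2.0) receives a combined reward of 0.5 * 1.0 + 2.0 * 0.5 = 1.5.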
print(f"Rewards after weights: {rewards}")
# Reshape rewards and compute advantages
rewards_reshaped = rewards.reshape(batch_size, group_size)
@@ -397,15 +406,11 @@ def evaluate_grpo(
epsilon: float,
group_size: int,
max_seq_length,
temperature: float,
reward_funcs: Optional[List[RewardFunctions]] = None,
loss_fn: callable = grpo_loss,
iterate_batches: callable = iterate_grpo_batches
):
"""
Evaluate model using GRPO loss.
Returns:
tuple: (average loss, number of tokens, average metrics)
"""
all_losses = 0
ntokens = 0
all_metrics = None
@@ -428,7 +433,8 @@ def evaluate_grpo(
beta=beta,
group_size=group_size,
epsilon=epsilon,
ref_model=ref_model,
temperature=temperature
)
all_losses += losses * toks
@@ -442,12 +448,10 @@ def evaluate_grpo(
mx.eval(all_losses, ntokens)
# Aggregate across distributed workers
all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
all_metrics = {k: mx.distributed.all_sum(v) for k, v in all_metrics.items()}
# Calculate averages
avg_metrics = {k: (v / ntokens).item() for k, v in all_metrics.items()}
avg_loss = (all_losses / ntokens).item()
@@ -486,8 +490,6 @@ def train_grpo(
state = [model.state, optimizer.state]
def step(batch):
# Forward and backward pass
(loss, toks, metrics), grad = loss_value_and_grad(
model,
tokenizer=tokenizer,
@@ -498,12 +500,11 @@ def train_grpo(
epsilon=args.epsilon,
ref_model=ref_model,
max_tokens=args.max_completion_length,
temperature=args.temperature
)
# All reduce the gradients if running in distributed mode
grad = average_gradients(grad)
# Model update
optimizer.update(model, grad)
return loss, toks, metrics
@@ -536,8 +537,6 @@ def train_grpo(
train=True,
),
):
# Report validation loss if needed, the first validation loss
# is always measured before any training.
if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
stop = time.perf_counter()
val_loss, val_ntokens, val_metrics = evaluate_grpo(
@@ -553,6 +552,7 @@ def train_grpo(
max_seq_length=args.max_seq_length,
beta=args.beta,
epsilon=args.epsilon,
temperature=args.temperature,
iterate_batches=iterate_batches,
)
val_time = time.perf_counter() - stop
@@ -566,7 +566,6 @@ def train_grpo(
f"Val kl {val_metrics['kl']:.3f}"
)
# Add reward function specific metrics
for i, reward_func in enumerate(reward_funcs):
val_metrics_str += (
f", Val {reward_func.__name__}_mean {val_metrics[f'{reward_func.__name__}_mean']:.3f}, "
@@ -622,7 +621,6 @@ def train_grpo(
f"KL {avg_metrics['kl']:.3f}"
)
# Add reward function specific metrics
for i, reward_func in enumerate(reward_funcs):
func_name = reward_func.__name__
train_metrics_str += (
@@ -656,7 +654,6 @@ def train_grpo(
steps = 0
start = time.perf_counter()
# Save adapter weights
if it % args.steps_per_save == 0:
adapter_weights = dict(tree_flatten(model.trainable_parameters()))
mx.save_safetensors(str(args.adapter_file), adapter_weights)
@@ -669,7 +666,6 @@ def train_grpo(
f"{args.adapter_file} and {checkpoint}."
)
# Save final weights
adapter_weights = dict(tree_flatten(model.trainable_parameters()))
mx.save_safetensors(str(args.adapter_file), adapter_weights)
print(f"Saved final weights to {args.adapter_file}.")