removing comments + adding temperature + reward weighting

Goekdeniz-Guelmez 2025-02-15 15:29:22 +01:00
parent baeb9f117f
commit 5ec4790656
2 changed files with 64 additions and 52 deletions

View File

@@ -1,21 +1,21 @@
# Copyright © 2024 Apple Inc.
from pathlib import Path
import argparse
import types
import math
import os
import re
import types
from pathlib import Path
import mlx.nn as nn
import mlx.optimizers as optim
import mlx.nn as nn
import numpy as np
import yaml
from .tuner.grpo_trainer import GRPOTrainingArgs, evaluate_grpo, train_grpo
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
from .tokenizer_utils import TokenizerWrapper
from .tuner.datasets import load_dataset
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
from .tuner.grpo_trainer import GRPOTrainingArgs, evaluate_grpo, train_grpo
from .tuner.utils import (
build_schedule,
linear_to_lora_layers,
@@ -73,6 +73,8 @@ CONFIG_DEFAULTS = {
"max_completion_length": 512,
"use_chat_template": False,
"use_prompt": False,
"temperature": 1.0,
"reward_weights": None,
}
@@ -224,6 +226,18 @@ def build_parser():
help="Rather to use the prompt from the R1 paper.",
default=None,
)
parser.add_argument(
"--temperature",
type=float,
help="Temperature for sampling. The higher the temperature, the more random the completions.",
default=1.0,
)
parser.add_argument(
"--reward-weights",
type=str,
help="Weights for each reward function. Must match the number of reward functions and be in this format [0.1, 0.2, 0.3, 0.4, 0.5]. If not given, all rewards are weighted equally with weight `1.0`.",
default=None,
)
return parser
def train_model_grpo(model, tokenizer, args, opt, train_set, valid_set, adapter_file, training_callback):
@@ -241,7 +255,9 @@ def train_model_grpo(model, tokenizer, args, opt, train_set, valid_set, adapter_
beta=args.beta,
group_size=args.group_size,
epsilon=args.epsilon,
reference_model_path=args.reference_model_path
reference_model_path=args.reference_model_path,
temperature=args.temperature,
reward_weights=[float(x) for x in args.reward_weights.strip('[]').split(',')] if args.reward_weights else None
)
if args.reference_model_path:
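Note: for context, a minimal sketch of how the new `--reward-weights` string is parsed into a list of floats, mirroring the inline expression used in `train_model_grpo` above. The standalone helper name is hypothetical, not part of the commit.

```python
from typing import List, Optional

def parse_reward_weights(raw: Optional[str]) -> Optional[List[float]]:
    # Mirrors the expression above: accept "[0.1, 0.2, 0.3]" (or "0.1,0.2,0.3"),
    # return a list of floats, or None so every reward keeps the default weight 1.0.
    if not raw:
        return None
    return [float(x) for x in raw.strip("[]").split(",")]

assert parse_reward_weights("[0.1, 0.2, 0.3]") == [0.1, 0.2, 0.3]
assert parse_reward_weights(None) is None
```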

View File

@@ -35,13 +35,24 @@ class GRPOTrainingArgs(TrainingArgs):
"help": "Path to reference model weights. If None, uses the same model."
}
)
temperature: float = field(
default=1.0,
metadata={
"help": "Temperature for sampling. The higher the temperature, the more random the completions."
}
)
reward_weights: Optional[List[float]] = field(
default=None,
metadata={
"help": "Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are weighted equally with weight `1.0`."
}
)
RewardFunctions = Callable[[List[str], List[str], List[str]], List[float]]
def r1_extract_xml_answer(text: str) -> str:
"""Extracts the answer from an XML formatted text string."""
try:
answer = text.split("<answer>")[-1]
answer = answer.split("</answer>")[0]
@@ -52,14 +63,12 @@ def r1_extract_xml_answer(text: str) -> str:
def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
"""Ensures we always return a list of floats."""
if not completions:
return [0.0] * len(prompts)
extracted_responses = [r1_extract_xml_answer(r) for r in completions]
return [0.5 if r and r.isdigit() else 0.0 for r in extracted_responses]
def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
"""Ensures we always return a list of floats."""
if not completions or not answer:
return [0.0] * len(prompts)
extracted_responses = [r1_extract_xml_answer(r) for r in completions]
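Note: the `RewardFunctions` alias above fixes the interface that `reward_weights` must line up with: each function takes `prompts`, `completions`, and `answer` and returns one float per completion, like the `r1_*` functions shown here. A hedged sketch of a custom reward function with that signature (the function itself is hypothetical, not part of the commit):

```python
from typing import List

def length_penalty_reward(prompts: List[str], completions: List[str], answer: List[str], **kwargs) -> List[float]:
    # One score per completion; mildly penalize very long outputs.
    if not completions:
        return [0.0] * len(prompts)
    return [max(0.0, 1.0 - len(text) / 2000.0) for text in completions]
```

If such a function were appended to the list of reward functions, `reward_weights` would need one extra entry in the same position.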
@@ -67,7 +76,6 @@ def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kw
def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
"""Ensures we always return a list of floats."""
if not completions:
return [0.0] * len(prompts)
pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
@@ -76,7 +84,6 @@ def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, *
def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
"""Ensures we always return a list of floats."""
if not completions:
return [0.0] * len(prompts)
pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>\n$"
@@ -85,16 +92,13 @@ def r1_strict_format_reward_func(prompts: list, completions: list, answer: list,
def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
"""Ensures we always return a list of floats."""
if not completions:
return [0.0] * len(prompts)
scores = []
for text in completions:
if not text:
scores.append(0.0)
continue
count = 0.0
if text.count("<think>\n") == 1:
count += 0.125
@@ -104,13 +108,9 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> li
count += 0.125
if text.count("\n</answer>\n") == 1:
count += 0.125
# Penalize extra text after </answer>
end_text = text.split("\n</answer>\n")[-1]
count -= len(end_text) * 0.001 if len(end_text) > 0 else 0
scores.append(max(0.0, count))
scores.append(max(0.0, count)) # Ensure non-negative score
return scores
@@ -119,22 +119,18 @@ def generate_grpo(model: nn.Module, prompt, max_tokens, tokenizer, temperature):
prompt = prompt[None, :]
if prompt.shape[1] == 0:
return None
end_sequence = tokenizer.encode("</answer>")
end_sequence_length = len(end_sequence)
initial_length = prompt.shape[1]
output = mx.zeros((initial_length + max_tokens,), dtype=mx.int32)
output[:initial_length] = prompt[0]
current_length = initial_length
try:
def sample(logits):
if temperature > 0:
logits /= temperature
logprobs = logits - mx.logsumexp(logits, keepdims=True)
return mx.random.categorical(logprobs[None, :]).astype(mx.int32)[0]
for _ in range(max_tokens):
current_input = output[:current_length][None, :]
logits = model(current_input)
@@ -143,18 +139,14 @@ def generate_grpo(model: nn.Module, prompt, max_tokens, tokenizer, temperature):
token_value = next_token.item()
output[current_length] = token_value
current_length += 1
if token_value == tokenizer.eos_token_id:
break
if current_length >= end_sequence_length:
last_tokens = output[current_length - end_sequence_length:current_length].tolist()
if last_tokens == end_sequence:
break
if current_length > initial_length:
return output[:current_length]
except Exception as e:
print(f"Generation error: {str(e)}")
return None
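Note: the `sample` closure above implements plain temperature sampling: logits are scaled by `1 / temperature` and a token is drawn from the resulting categorical distribution. A minimal sketch of the same idea in isolation (the helper name is mine, not from the commit):

```python
import mlx.core as mx

def sample_token(logits: mx.array, temperature: float) -> mx.array:
    # Scale logits, normalize to log-probabilities, then sample.
    # Small temperatures concentrate mass on the argmax; large ones flatten the distribution.
    if temperature > 0:
        logits = logits / temperature
    logprobs = logits - mx.logsumexp(logits, keepdims=True)
    return mx.random.categorical(logprobs[None, :]).astype(mx.int32)[0]

logits = mx.array([2.0, 0.5, -1.0])
print(sample_token(logits, temperature=0.1))  # almost always index 0
print(sample_token(logits, temperature=2.0))  # noticeably more random
```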
@@ -192,9 +184,10 @@
group_size=4,
epsilon=1e-4,
max_tokens=64,
temperature=1.0
temperature=1.0,
reward_weights=None
):
prompt_tokens, answer_tokens, prompt_text, answer_text = batch
prompt_tokens, _, prompt_text, answer_text = batch
batch_size = len(prompt_tokens)
all_completions = []
@@ -273,18 +266,34 @@ def grpo_loss(
token_log_probs = mx.stack(padded_log_probs)
ref_token_log_probs = mx.stack(padded_ref_log_probs)
# Rewards and advantages
# Create array to store rewards from each function
rewards = mx.zeros((len(all_completions),))
all_func_rewards = []
# Collect rewards from each function separately
for reward_func in reward_funcs:
func_rewards = mx.array(reward_func(
prompts=expanded_prompts,
completions=all_completion_texts,
answer=expanded_answers
))
rewards += func_rewards
all_func_rewards.append(func_rewards)
if len(reward_funcs) > 1:
# Stack rewards to shape (num_samples, num_funcs)
rewards /= len(reward_funcs)
rewards = mx.stack(all_func_rewards, axis=1)
print(f"Rewards: {rewards}")
# Apply weights and sum
if reward_weights is not None:
if len(reward_weights) != len(reward_funcs):
raise ValueError(
f"Number of reward weights ({len(reward_weights)}) must match number of reward "
f"functions ({len(reward_funcs)})"
)
reward_weights = mx.array(reward_weights, dtype=mx.float32)
else:
reward_weights = mx.ones(len(reward_funcs), dtype=mx.float32)
rewards = (rewards * mx.expand_dims(reward_weights, 0)).sum(axis=1)
print(f"Rewards after weights: {rewards}")
# Reshape rewards and compute advantages
rewards_reshaped = rewards.reshape(batch_size, group_size)
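Note: the weighting logic above stacks the per-function rewards into a `(num_samples, num_funcs)` matrix, multiplies by one weight per reward function, and sums across functions. A small sketch with made-up numbers (three reward functions, four sampled completions; none of these values come from the commit):

```python
import mlx.core as mx

format_rewards = mx.array([0.5, 0.0, 0.5, 0.5])    # hypothetical per-completion scores
int_rewards = mx.array([0.5, 0.5, 0.0, 0.5])
accuracy_rewards = mx.array([2.0, 0.0, 0.0, 2.0])

rewards = mx.stack([format_rewards, int_rewards, accuracy_rewards], axis=1)  # shape (4, 3)
weights = mx.array([1.0, 1.0, 2.0], dtype=mx.float32)  # one weight per reward function
total = (rewards * mx.expand_dims(weights, 0)).sum(axis=1)  # shape (4,)
print(total)  # [5.0, 0.5, 0.5, 5.0]
```

As in the diff above, a mismatch between the number of weights and the number of reward functions raises a ValueError.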
@@ -397,15 +406,11 @@
epsilon: float,
group_size: int,
max_seq_length,
temperature: float,
reward_funcs: Optional[List[RewardFunctions]] = None,
loss_fn: callable = grpo_loss,
iterate_batches: callable = iterate_grpo_batches
):
"""
Evaluate model using GRPO loss.
Returns:
tuple: (average loss, number of tokens, average metrics)
"""
all_losses = 0
ntokens = 0
all_metrics = None
@@ -428,7 +433,8 @@
beta=beta,
group_size=group_size,
epsilon=epsilon,
ref_model=ref_model
ref_model=ref_model,
temperature=temperature
)
all_losses += losses * toks
@@ -442,12 +448,10 @@
mx.eval(all_losses, ntokens)
# Aggregate across distributed workers
all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
all_metrics = {k: mx.distributed.all_sum(v) for k, v in all_metrics.items()}
# Calculate averages
avg_metrics = {k: (v / ntokens).item() for k, v in all_metrics.items()}
avg_loss = (all_losses / ntokens).item()
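Note: the evaluation loop above keeps a token-weighted running loss: each batch contributes `loss * ntokens`, the sums are reduced across workers with `mx.distributed.all_sum`, and the final average divides by the total token count. A toy illustration with invented numbers:

```python
import mlx.core as mx

batch_losses = mx.array([0.9, 1.1, 1.0])   # hypothetical per-batch mean losses
batch_tokens = mx.array([128, 64, 256])    # tokens in each batch
avg_loss = (batch_losses * batch_tokens).sum() / batch_tokens.sum()
print(avg_loss.item())  # ~0.986
```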
@@ -486,8 +490,6 @@ def train_grpo(
state = [model.state, optimizer.state]
def step(batch):
# Forward and backward pass
(loss, toks, metrics), grad = loss_value_and_grad(
model,
tokenizer=tokenizer,
@@ -498,12 +500,11 @@ def train_grpo(
epsilon=args.epsilon,
ref_model=ref_model,
max_tokens=args.max_completion_length,
temperature=args.temperature
)
# All reduce the gradients if running in distributed mode
grad = average_gradients(grad)
# Model update
optimizer.update(model, grad)
return loss, toks, metrics
@@ -536,8 +537,6 @@ def train_grpo(
train=True,
),
):
# Report validation loss if needed, the first validation loss
# is always measured before any training.
if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
stop = time.perf_counter()
val_loss, val_ntokens, val_metrics = evaluate_grpo(
@@ -553,6 +552,7 @@ def train_grpo(
max_seq_length=args.max_seq_length,
beta=args.beta,
epsilon=args.epsilon,
temperature=args.temperature,
iterate_batches=iterate_batches,
)
val_time = time.perf_counter() - stop
@@ -566,7 +566,6 @@ def train_grpo(
f"Val kl {val_metrics['kl']:.3f}"
)
# Add reward function specific metrics
for i, reward_func in enumerate(reward_funcs):
val_metrics_str += (
f", Val {reward_func.__name__}_mean {val_metrics[f'{reward_func.__name__}_mean']:.3f}, "
@@ -622,7 +621,6 @@ def train_grpo(
f"KL {avg_metrics['kl']:.3f}"
)
# Add reward function specific metrics
for i, reward_func in enumerate(reward_funcs):
func_name = reward_func.__name__
train_metrics_str += (
@@ -656,7 +654,6 @@ def train_grpo(
steps = 0
start = time.perf_counter()
# Save adapter weights
if it % args.steps_per_save == 0:
adapter_weights = dict(tree_flatten(model.trainable_parameters()))
mx.save_safetensors(str(args.adapter_file), adapter_weights)
@@ -669,7 +666,6 @@ def train_grpo(
f"{args.adapter_file} and {checkpoint}."
)
# Save final weights
adapter_weights = dict(tree_flatten(model.trainable_parameters()))
mx.save_safetensors(str(args.adapter_file), adapter_weights)
print(f"Saved final weights to {args.adapter_file}.")