mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-27 19:31:20 +08:00
removing comments + adding temperature + reward weighting
This commit is contained in:
parent
baeb9f117f
commit
5ec4790656
@ -1,21 +1,21 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
import types
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import types
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.nn as nn
|
||||
import mlx.optimizers as optim
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
import yaml
|
||||
|
||||
from .tuner.grpo_trainer import GRPOTrainingArgs, evaluate_grpo, train_grpo
|
||||
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
|
||||
from .tokenizer_utils import TokenizerWrapper
|
||||
from .tuner.datasets import load_dataset
|
||||
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
|
||||
from .tuner.grpo_trainer import GRPOTrainingArgs, evaluate_grpo, train_grpo
|
||||
from .tuner.utils import (
|
||||
build_schedule,
|
||||
linear_to_lora_layers,
|
||||
@ -73,6 +73,8 @@ CONFIG_DEFAULTS = {
|
||||
"max_completion_length": 512,
|
||||
"use_chat_template": False,
|
||||
"use_prompt": False,
|
||||
"temperature": 1.0,
|
||||
"reward_weights": None,
|
||||
}
|
||||
|
||||
|
||||
@ -224,6 +226,18 @@ def build_parser():
|
||||
help="Rather to use the prompt from the R1 paper.",
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temperature",
|
||||
type=float,
|
||||
help="Temperature for sampling. The higher the temperature, the more random the completions.",
|
||||
default=1.0,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reward-weights",
|
||||
type=str,
|
||||
help="Weights for each reward function. Must match the number of reward functions and be in this format [0.1, 0.2, 0.3, 0.4, 0.5]. If not given, all rewards are weighted equally with weight `1.0`.",
|
||||
default=None,
|
||||
)
|
||||
return parser
|
||||
|
||||
def train_model_grpo(model, tokenizer, args, opt, train_set, valid_set, adapter_file, training_callback):
|
||||
@ -241,7 +255,9 @@ def train_model_grpo(model, tokenizer, args, opt, train_set, valid_set, adapter_
|
||||
beta=args.beta,
|
||||
group_size=args.group_size,
|
||||
epsilon=args.epsilon,
|
||||
reference_model_path=args.reference_model_path
|
||||
reference_model_path=args.reference_model_path,
|
||||
temperature=args.temperature,
|
||||
reward_weights=[float(x) for x in args.reward_weights.strip('[]').split(',')] if args.reward_weights else None
|
||||
)
|
||||
|
||||
if args.reference_model_path:
|
||||
|
@ -35,13 +35,24 @@ class GRPOTrainingArgs(TrainingArgs):
|
||||
"help": "Path to reference model weights. If None, uses the same model."
|
||||
}
|
||||
)
|
||||
temperature: float = field(
|
||||
default=1.0,
|
||||
metadata={
|
||||
"help": "Temperature for sampling. The higher the temperature, the more random the completions."
|
||||
}
|
||||
)
|
||||
reward_weights: Optional[List[float]] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are weighted equally with weight `1.0`."
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
RewardFunctions = Callable[[List[str], List[str], List[str]], List[float]]
|
||||
|
||||
|
||||
def r1_extract_xml_answer(text: str) -> str:
|
||||
"""Extracts the answer from an XML formatted text string."""
|
||||
try:
|
||||
answer = text.split("<answer>")[-1]
|
||||
answer = answer.split("</answer>")[0]
|
||||
@ -52,14 +63,12 @@ def r1_extract_xml_answer(text: str) -> str:
|
||||
|
||||
|
||||
def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
|
||||
"""Ensures we always return a list of floats."""
|
||||
if not completions:
|
||||
return [0.0] * len(prompts)
|
||||
extracted_responses = [r1_extract_xml_answer(r) for r in completions]
|
||||
return [0.5 if r and r.isdigit() else 0.0 for r in extracted_responses]
|
||||
|
||||
def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
|
||||
"""Ensures we always return a list of floats."""
|
||||
if not completions or not answer:
|
||||
return [0.0] * len(prompts)
|
||||
extracted_responses = [r1_extract_xml_answer(r) for r in completions]
|
||||
@ -67,7 +76,6 @@ def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kw
|
||||
|
||||
|
||||
def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
|
||||
"""Ensures we always return a list of floats."""
|
||||
if not completions:
|
||||
return [0.0] * len(prompts)
|
||||
pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
|
||||
@ -76,7 +84,6 @@ def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, *
|
||||
|
||||
|
||||
def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
|
||||
"""Ensures we always return a list of floats."""
|
||||
if not completions:
|
||||
return [0.0] * len(prompts)
|
||||
pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>\n$"
|
||||
@ -85,16 +92,13 @@ def r1_strict_format_reward_func(prompts: list, completions: list, answer: list,
|
||||
|
||||
|
||||
def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
|
||||
"""Ensures we always return a list of floats."""
|
||||
if not completions:
|
||||
return [0.0] * len(prompts)
|
||||
|
||||
scores = []
|
||||
for text in completions:
|
||||
if not text:
|
||||
scores.append(0.0)
|
||||
continue
|
||||
|
||||
count = 0.0
|
||||
if text.count("<think>\n") == 1:
|
||||
count += 0.125
|
||||
@ -104,13 +108,9 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> li
|
||||
count += 0.125
|
||||
if text.count("\n</answer>\n") == 1:
|
||||
count += 0.125
|
||||
|
||||
# Penalize extra text after </answer>
|
||||
end_text = text.split("\n</answer>\n")[-1]
|
||||
count -= len(end_text) * 0.001 if len(end_text) > 0 else 0
|
||||
|
||||
scores.append(max(0.0, count)) # Ensure non-negative score
|
||||
|
||||
scores.append(max(0.0, count))
|
||||
return scores
|
||||
|
||||
|
||||
@ -119,22 +119,18 @@ def generate_grpo(model: nn.Module, prompt, max_tokens, tokenizer, temperature):
|
||||
prompt = prompt[None, :]
|
||||
if prompt.shape[1] == 0:
|
||||
return None
|
||||
|
||||
end_sequence = tokenizer.encode("</answer>")
|
||||
end_sequence_length = len(end_sequence)
|
||||
|
||||
initial_length = prompt.shape[1]
|
||||
output = mx.zeros((initial_length + max_tokens,), dtype=mx.int32)
|
||||
output[:initial_length] = prompt[0]
|
||||
current_length = initial_length
|
||||
|
||||
try:
|
||||
def sample(logits):
|
||||
if temperature > 0:
|
||||
logits /= temperature
|
||||
logprobs = logits - mx.logsumexp(logits, keepdims=True)
|
||||
return mx.random.categorical(logprobs[None, :]).astype(mx.int32)[0]
|
||||
|
||||
for _ in range(max_tokens):
|
||||
current_input = output[:current_length][None, :]
|
||||
logits = model(current_input)
|
||||
@ -143,18 +139,14 @@ def generate_grpo(model: nn.Module, prompt, max_tokens, tokenizer, temperature):
|
||||
token_value = next_token.item()
|
||||
output[current_length] = token_value
|
||||
current_length += 1
|
||||
|
||||
if token_value == tokenizer.eos_token_id:
|
||||
break
|
||||
|
||||
if current_length >= end_sequence_length:
|
||||
last_tokens = output[current_length - end_sequence_length:current_length].tolist()
|
||||
if last_tokens == end_sequence:
|
||||
break
|
||||
|
||||
if current_length > initial_length:
|
||||
return output[:current_length]
|
||||
|
||||
except Exception as e:
|
||||
print(f"Generation error: {str(e)}")
|
||||
return None
|
||||
@ -192,9 +184,10 @@ def grpo_loss(
|
||||
group_size=4,
|
||||
epsilon=1e-4,
|
||||
max_tokens=64,
|
||||
temperature=1.0
|
||||
temperature=1.0,
|
||||
reward_weights=None
|
||||
):
|
||||
prompt_tokens, answer_tokens, prompt_text, answer_text = batch
|
||||
prompt_tokens, _, prompt_text, answer_text = batch
|
||||
batch_size = len(prompt_tokens)
|
||||
|
||||
all_completions = []
|
||||
@ -273,18 +266,34 @@ def grpo_loss(
|
||||
token_log_probs = mx.stack(padded_log_probs)
|
||||
ref_token_log_probs = mx.stack(padded_ref_log_probs)
|
||||
|
||||
# Rewards and advantages
|
||||
rewards = mx.zeros((len(all_completions),))
|
||||
# Create array to store rewards from each function
|
||||
all_func_rewards = []
|
||||
|
||||
# Collect rewards from each function separately
|
||||
for reward_func in reward_funcs:
|
||||
func_rewards = mx.array(reward_func(
|
||||
prompts=expanded_prompts,
|
||||
completions=all_completion_texts,
|
||||
answer=expanded_answers
|
||||
))
|
||||
rewards += func_rewards
|
||||
all_func_rewards.append(func_rewards)
|
||||
|
||||
if len(reward_funcs) > 1:
|
||||
rewards /= len(reward_funcs)
|
||||
# Stack rewards to shape (num_samples, num_funcs)
|
||||
rewards = mx.stack(all_func_rewards, axis=1)
|
||||
print(f"Rewards: {rewards}")
|
||||
|
||||
# Apply weights and sum
|
||||
if reward_weights is not None:
|
||||
if len(reward_weights) != len(reward_funcs):
|
||||
raise ValueError(
|
||||
f"Number of reward weights ({len(reward_weights)}) must match number of reward "
|
||||
f"functions ({len(reward_funcs)})"
|
||||
)
|
||||
reward_weights = mx.array(reward_weights, dtype=mx.float32)
|
||||
else:
|
||||
reward_weights = mx.ones(len(reward_funcs), dtype=mx.float32)
|
||||
rewards = (rewards * mx.expand_dims(reward_weights, 0)).sum(axis=1)
|
||||
print(f"Rewards after weights: {rewards}")
|
||||
|
||||
# Reshape rewards and compute advantages
|
||||
rewards_reshaped = rewards.reshape(batch_size, group_size)
|
||||
@ -397,15 +406,11 @@ def evaluate_grpo(
|
||||
epsilon: float,
|
||||
group_size: int,
|
||||
max_seq_length,
|
||||
temperature: float,
|
||||
reward_funcs: Optional[List[RewardFunctions]] = None,
|
||||
loss_fn: callable = grpo_loss,
|
||||
iterate_batches: callable = iterate_grpo_batches
|
||||
):
|
||||
"""
|
||||
Evaluate model using GRPO loss.
|
||||
Returns:
|
||||
tuple: (average loss, number of tokens, average metrics)
|
||||
"""
|
||||
all_losses = 0
|
||||
ntokens = 0
|
||||
all_metrics = None
|
||||
@ -428,7 +433,8 @@ def evaluate_grpo(
|
||||
beta=beta,
|
||||
group_size=group_size,
|
||||
epsilon=epsilon,
|
||||
ref_model=ref_model
|
||||
ref_model=ref_model,
|
||||
temperature=temperature
|
||||
)
|
||||
|
||||
all_losses += losses * toks
|
||||
@ -442,12 +448,10 @@ def evaluate_grpo(
|
||||
|
||||
mx.eval(all_losses, ntokens)
|
||||
|
||||
# Aggregate across distributed workers
|
||||
all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
|
||||
ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
|
||||
all_metrics = {k: mx.distributed.all_sum(v) for k, v in all_metrics.items()}
|
||||
|
||||
# Calculate averages
|
||||
avg_metrics = {k: (v / ntokens).item() for k, v in all_metrics.items()}
|
||||
avg_loss = (all_losses / ntokens).item()
|
||||
|
||||
@ -486,8 +490,6 @@ def train_grpo(
|
||||
state = [model.state, optimizer.state]
|
||||
|
||||
def step(batch):
|
||||
|
||||
# Forward and backward pass
|
||||
(loss, toks, metrics), grad = loss_value_and_grad(
|
||||
model,
|
||||
tokenizer=tokenizer,
|
||||
@ -498,12 +500,11 @@ def train_grpo(
|
||||
epsilon=args.epsilon,
|
||||
ref_model=ref_model,
|
||||
max_tokens=args.max_completion_length,
|
||||
temperature=args.temperature
|
||||
)
|
||||
|
||||
# All reduce the gradients if running in distributed mode
|
||||
grad = average_gradients(grad)
|
||||
|
||||
# Model update
|
||||
optimizer.update(model, grad)
|
||||
|
||||
return loss, toks, metrics
|
||||
@ -536,8 +537,6 @@ def train_grpo(
|
||||
train=True,
|
||||
),
|
||||
):
|
||||
# Report validation loss if needed, the first validation loss
|
||||
# is always measured before any training.
|
||||
if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
|
||||
stop = time.perf_counter()
|
||||
val_loss, val_ntokens, val_metrics = evaluate_grpo(
|
||||
@ -553,6 +552,7 @@ def train_grpo(
|
||||
max_seq_length=args.max_seq_length,
|
||||
beta=args.beta,
|
||||
epsilon=args.epsilon,
|
||||
temperature=args.temperature,
|
||||
iterate_batches=iterate_batches,
|
||||
)
|
||||
val_time = time.perf_counter() - stop
|
||||
@ -566,7 +566,6 @@ def train_grpo(
|
||||
f"Val kl {val_metrics['kl']:.3f}"
|
||||
)
|
||||
|
||||
# Add reward function specific metrics
|
||||
for i, reward_func in enumerate(reward_funcs):
|
||||
val_metrics_str += (
|
||||
f", Val {reward_func.__name__}_mean {val_metrics[f'{reward_func.__name__}_mean']:.3f}, "
|
||||
@ -622,7 +621,6 @@ def train_grpo(
|
||||
f"KL {avg_metrics['kl']:.3f}"
|
||||
)
|
||||
|
||||
# Add reward function specific metrics
|
||||
for i, reward_func in enumerate(reward_funcs):
|
||||
func_name = reward_func.__name__
|
||||
train_metrics_str += (
|
||||
@ -656,7 +654,6 @@ def train_grpo(
|
||||
steps = 0
|
||||
start = time.perf_counter()
|
||||
|
||||
# Save adapter weights
|
||||
if it % args.steps_per_save == 0:
|
||||
adapter_weights = dict(tree_flatten(model.trainable_parameters()))
|
||||
mx.save_safetensors(str(args.adapter_file), adapter_weights)
|
||||
@ -669,7 +666,6 @@ def train_grpo(
|
||||
f"{args.adapter_file} and {checkpoint}."
|
||||
)
|
||||
|
||||
# Save final weights
|
||||
adapter_weights = dict(tree_flatten(model.trainable_parameters()))
|
||||
mx.save_safetensors(str(args.adapter_file), adapter_weights)
|
||||
print(f"Saved final weights to {args.adapter_file}.")
|
Loading…
Reference in New Issue
Block a user