Remove comments, add sampling temperature, and add reward weighting

Goekdeniz-Guelmez
2025-02-15 15:29:22 +01:00
parent baeb9f117f
commit 5ec4790656
2 changed files with 64 additions and 52 deletions


@@ -35,13 +35,24 @@ class GRPOTrainingArgs(TrainingArgs):
"help": "Path to reference model weights. If None, uses the same model."
}
)
temperature: float = field(
default=1.0,
metadata={
"help": "Temperature for sampling. The higher the temperature, the more random the completions."
}
)
reward_weights: Optional[List[float]] = field(
default=None,
metadata={
"help": "Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are weighted equally with weight `1.0`."
}
)
RewardFunctions = Callable[[List[str], List[str], List[str]], List[float]]
def r1_extract_xml_answer(text: str) -> str:
"""Extracts the answer from an XML formatted text string."""
try:
answer = text.split("<answer>")[-1]
answer = answer.split("</answer>")[0]
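The two new GRPOTrainingArgs fields above can be set like any other training argument. A minimal sketch, assuming the inherited TrainingArgs fields keep their defaults:

args = GRPOTrainingArgs(
    temperature=0.8,  # < 1.0 sharpens sampling, > 1.0 makes completions more random
    reward_weights=[2.0, 0.5, 0.5, 0.5, 0.5],  # one weight per reward function; None weights all equally
)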
@@ -52,14 +63,12 @@ def r1_extract_xml_answer(text: str) -> str:
def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
"""Ensures we always return a list of floats."""
if not completions:
return [0.0] * len(prompts)
extracted_responses = [r1_extract_xml_answer(r) for r in completions]
return [0.5 if r and r.isdigit() else 0.0 for r in extracted_responses]
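As a quick check of the two helpers above, with return values that follow directly from the code shown:

print(r1_extract_xml_answer("<think>2+2=4</think><answer>4</answer>"))  # -> "4"
print(r1_int_reward_func(
    prompts=["What is 2+2?"],
    completions=["<think>2+2=4</think><answer>4</answer>"],
    answer=["4"],
))  # -> [0.5], since the extracted answer is all digits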
def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
"""Ensures we always return a list of floats."""
if not completions or not answer:
return [0.0] * len(prompts)
extracted_responses = [r1_extract_xml_answer(r) for r in completions]
@@ -67,7 +76,6 @@ def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kw
def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
"""Ensures we always return a list of floats."""
if not completions:
return [0.0] * len(prompts)
pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
@@ -76,7 +84,6 @@ def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, *
def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
"""Ensures we always return a list of floats."""
if not completions:
return [0.0] * len(prompts)
pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>\n$"
@@ -85,16 +92,13 @@ def r1_strict_format_reward_func(prompts: list, completions: list, answer: list,
def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
"""Ensures we always return a list of floats."""
if not completions:
return [0.0] * len(prompts)
scores = []
for text in completions:
if not text:
scores.append(0.0)
continue
count = 0.0
if text.count("<think>\n") == 1:
count += 0.125
@@ -104,13 +108,9 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> li
count += 0.125
if text.count("\n</answer>\n") == 1:
count += 0.125
# Penalize extra text after </answer>
end_text = text.split("\n</answer>\n")[-1]
count -= len(end_text) * 0.001 if len(end_text) > 0 else 0
scores.append(max(0.0, count)) # Ensure non-negative score
scores.append(max(0.0, count))
return scores
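A worked example of the scoring above, assuming the two tag checks hidden by the hunk boundary also add 0.125 each: a completion containing all four expected tag lines earns 4 * 0.125 = 0.5, and 20 characters of trailing text after the closing </answer> subtract 20 * 0.001 = 0.02, leaving 0.48; max(0.0, count) then keeps the score from going negative.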
@@ -119,22 +119,18 @@ def generate_grpo(model: nn.Module, prompt, max_tokens, tokenizer, temperature):
prompt = prompt[None, :]
if prompt.shape[1] == 0:
return None
end_sequence = tokenizer.encode("</answer>")
end_sequence_length = len(end_sequence)
initial_length = prompt.shape[1]
output = mx.zeros((initial_length + max_tokens,), dtype=mx.int32)
output[:initial_length] = prompt[0]
current_length = initial_length
try:
def sample(logits):
if temperature > 0:
logits /= temperature
logprobs = logits - mx.logsumexp(logits, keepdims=True)
return mx.random.categorical(logprobs[None, :]).astype(mx.int32)[0]
for _ in range(max_tokens):
current_input = output[:current_length][None, :]
logits = model(current_input)
@@ -143,18 +139,14 @@ def generate_grpo(model: nn.Module, prompt, max_tokens, tokenizer, temperature):
token_value = next_token.item()
output[current_length] = token_value
current_length += 1
if token_value == tokenizer.eos_token_id:
break
if current_length >= end_sequence_length:
last_tokens = output[current_length - end_sequence_length:current_length].tolist()
if last_tokens == end_sequence:
break
if current_length > initial_length:
return output[:current_length]
except Exception as e:
print(f"Generation error: {str(e)}")
return None
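The sampling rule inside generate_grpo can be read in isolation as follows. This is only a restatement of the closure above for clarity (note that temperature == 0 falls through to sampling from the unscaled logits rather than greedy argmax):

import mlx.core as mx

def sample_token(logits: mx.array, temperature: float) -> mx.array:
    # Divide logits by the temperature, normalize to log-probabilities,
    # then draw one token index categorically.
    if temperature > 0:
        logits = logits / temperature
    logprobs = logits - mx.logsumexp(logits, keepdims=True)
    return mx.random.categorical(logprobs[None, :]).astype(mx.int32)[0]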
@@ -192,9 +184,10 @@ def grpo_loss(
group_size=4,
epsilon=1e-4,
max_tokens=64,
temperature=1.0
temperature=1.0,
reward_weights=None
):
prompt_tokens, answer_tokens, prompt_text, answer_text = batch
prompt_tokens, _, prompt_text, answer_text = batch
batch_size = len(prompt_tokens)
all_completions = []
@@ -273,18 +266,34 @@ def grpo_loss(
token_log_probs = mx.stack(padded_log_probs)
ref_token_log_probs = mx.stack(padded_ref_log_probs)
# Rewards and advantages
rewards = mx.zeros((len(all_completions),))
# Create array to store rewards from each function
all_func_rewards = []
# Collect rewards from each function separately
for reward_func in reward_funcs:
func_rewards = mx.array(reward_func(
prompts=expanded_prompts,
completions=all_completion_texts,
answer=expanded_answers
))
rewards += func_rewards
all_func_rewards.append(func_rewards)
if len(reward_funcs) > 1:
rewards /= len(reward_funcs)
# Stack rewards to shape (num_samples, num_funcs)
rewards = mx.stack(all_func_rewards, axis=1)
print(f"Rewards: {rewards}")
# Apply weights and sum
if reward_weights is not None:
if len(reward_weights) != len(reward_funcs):
raise ValueError(
f"Number of reward weights ({len(reward_weights)}) must match number of reward "
f"functions ({len(reward_funcs)})"
)
reward_weights = mx.array(reward_weights, dtype=mx.float32)
else:
reward_weights = mx.ones(len(reward_funcs), dtype=mx.float32)
rewards = (rewards * mx.expand_dims(reward_weights, 0)).sum(axis=1)
print(f"Rewards after weights: {rewards}")
# Reshape rewards and compute advantages
rewards_reshaped = rewards.reshape(batch_size, group_size)
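A toy check of the new weighting step, using made-up rewards from two hypothetical reward functions (the weights here are illustrative, not defaults):

import mlx.core as mx

func_a = mx.array([1.0, 0.0, 1.0])  # e.g. an accuracy-style reward per completion
func_b = mx.array([0.5, 0.5, 0.0])  # e.g. a format-style reward per completion
rewards = mx.stack([func_a, func_b], axis=1)      # shape (3, 2)
weights = mx.array([2.0, 0.5], dtype=mx.float32)  # reward_weights, one per function
combined = (rewards * mx.expand_dims(weights, 0)).sum(axis=1)
print(combined)  # weighted sums: 2.25, 0.25, 2.0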
@@ -397,15 +406,11 @@ def evaluate_grpo(
epsilon: float,
group_size: int,
max_seq_length,
temperature: float,
reward_funcs: Optional[List[RewardFunctions]] = None,
loss_fn: callable = grpo_loss,
iterate_batches: callable = iterate_grpo_batches
):
"""
Evaluate model using GRPO loss.
Returns:
tuple: (average loss, number of tokens, average metrics)
"""
all_losses = 0
ntokens = 0
all_metrics = None
@@ -428,7 +433,8 @@ def evaluate_grpo(
beta=beta,
group_size=group_size,
epsilon=epsilon,
ref_model=ref_model
ref_model=ref_model,
temperature=temperature
)
all_losses += losses * toks
@@ -442,12 +448,10 @@ def evaluate_grpo(
mx.eval(all_losses, ntokens)
# Aggregate across distributed workers
all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
all_metrics = {k: mx.distributed.all_sum(v) for k, v in all_metrics.items()}
# Calculate averages
avg_metrics = {k: (v / ntokens).item() for k, v in all_metrics.items()}
avg_loss = (all_losses / ntokens).item()
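A quick sanity check of the averaging above: two evaluation batches with per-token losses 2.0 and 1.0 over 10 and 30 tokens accumulate 2.0 * 10 + 1.0 * 30 = 50 over 40 tokens, so the reported average is 50 / 40 = 1.25 rather than the naive (2.0 + 1.0) / 2 = 1.5. The metrics dictionary is normalized by the same token count.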
@@ -486,8 +490,6 @@ def train_grpo(
state = [model.state, optimizer.state]
def step(batch):
# Forward and backward pass
(loss, toks, metrics), grad = loss_value_and_grad(
model,
tokenizer=tokenizer,
@@ -498,12 +500,11 @@ def train_grpo(
epsilon=args.epsilon,
ref_model=ref_model,
max_tokens=args.max_completion_length,
temperature=args.temperature
)
# All reduce the gradients if running in distributed mode
grad = average_gradients(grad)
# Model update
optimizer.update(model, grad)
return loss, toks, metrics
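For intuition on the distributed path in step above: average_gradients combines the per-worker gradients so every rank applies the same update; with two workers holding gradients g1 and g2, each applies (g1 + g2) / 2 before optimizer.update.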
@@ -536,8 +537,6 @@ def train_grpo(
train=True,
),
):
# Report validation loss if needed, the first validation loss
# is always measured before any training.
if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
stop = time.perf_counter()
val_loss, val_ntokens, val_metrics = evaluate_grpo(
@@ -553,6 +552,7 @@ def train_grpo(
max_seq_length=args.max_seq_length,
beta=args.beta,
epsilon=args.epsilon,
temperature=args.temperature,
iterate_batches=iterate_batches,
)
val_time = time.perf_counter() - stop
@@ -566,7 +566,6 @@ def train_grpo(
f"Val kl {val_metrics['kl']:.3f}"
)
# Add reward function specific metrics
for i, reward_func in enumerate(reward_funcs):
val_metrics_str += (
f", Val {reward_func.__name__}_mean {val_metrics[f'{reward_func.__name__}_mean']:.3f}, "
@@ -622,7 +621,6 @@ def train_grpo(
f"KL {avg_metrics['kl']:.3f}"
)
# Add reward function specific metrics
for i, reward_func in enumerate(reward_funcs):
func_name = reward_func.__name__
train_metrics_str += (
@@ -656,7 +654,6 @@ def train_grpo(
steps = 0
start = time.perf_counter()
# Save adapter weights
if it % args.steps_per_save == 0:
adapter_weights = dict(tree_flatten(model.trainable_parameters()))
mx.save_safetensors(str(args.adapter_file), adapter_weights)
@@ -669,7 +666,6 @@ def train_grpo(
f"{args.adapter_file} and {checkpoint}."
)
# Save final weights
adapter_weights = dict(tree_flatten(model.trainable_parameters()))
mx.save_safetensors(str(args.adapter_file), adapter_weights)
print(f"Saved final weights to {args.adapter_file}.")