Goekdeniz-Guelmez 2025-02-05 14:38:09 +01:00
parent 35a2d99cf9
commit 0a19522ec4


@@ -45,44 +45,48 @@ def r1_extract_xml_answer(text: str) -> str:
         print("r1_extract_xml_answer returned empty string")
         return ""
 
 def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     """Ensures we always return a list of floats."""
-    if not completions: # Handle empty completions
+    if not completions:
         return [0.0] * len(prompts)
     extracted_responses = [r1_extract_xml_answer(r) for r in completions]
     return [0.5 if r and r.isdigit() else 0.0 for r in extracted_responses]
 
 def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     """Ensures we always return a list of floats."""
-    if not completions or not answer: # Handle empty inputs
+    if not completions or not answer:
         return [0.0] * len(prompts)
     extracted_responses = [r1_extract_xml_answer(r) for r in completions]
     return [2.0 if r and a and r == a else 0.0 for r, a in zip(extracted_responses, answer)]
 
 def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     """Ensures we always return a list of floats."""
-    if not completions: # Handle empty completions
+    if not completions:
         return [0.0] * len(prompts)
     pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
     matches = [bool(re.search(pattern, r)) if r else False for r in completions]
     return [0.5 if match else 0.0 for match in matches]
 
 def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     """Ensures we always return a list of floats."""
-    if not completions: # Handle empty completions
+    if not completions:
         return [0.0] * len(prompts)
     pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>\n$"
     matches = [bool(re.search(pattern, r)) if r else False for r in completions]
     return [0.5 if match else 0.0 for match in matches]
 
 def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     """Ensures we always return a list of floats."""
-    if not completions: # Handle empty completions
+    if not completions:
        return [0.0] * len(prompts)
     scores = []
     for text in completions:
-        if not text: # Handle None or empty text
+        if not text:
             scores.append(0.0)
             continue
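
Each reward function above follows the same contract: parallel lists of prompts, completions, and reference answers go in, and one float score per prompt comes out. A minimal usage sketch with the functions above in scope (the sample strings are illustrative, not taken from the repository):

sample_prompts = ["What is 6 * 7?"]
sample_completions = ["<think>6 * 7 = 42</think>\n<answer>42</answer>"]
sample_answers = ["42"]

# -> [2.0]: the extracted <answer> text matches the reference exactly
accuracy = r1_accuracy_reward_func(
    prompts=sample_prompts, completions=sample_completions, answer=sample_answers
)

# -> [0.5]: the completion contains the <think>...</think><answer>...</answer> pattern
soft_format = r1_soft_format_reward_func(
    prompts=sample_prompts, completions=sample_completions, answer=sample_answers
)
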
@@ -137,11 +141,9 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
         output[current_length] = token_value
         current_length += 1
 
-        # Check for EOS token
         if token_value == tokenizer.eos_token_id:
             break
 
-        # Check for "</answer>" sequence
         if current_length >= end_sequence_length:
             last_tokens = output[current_length - end_sequence_length:current_length].tolist()
             if last_tokens == end_sequence:
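
The two checks above stop generation either on the tokenizer's EOS token or once the most recent tokens spell out the closing </answer> tag. A simplified sketch of that stop condition, using plain Python lists instead of the preallocated MLX output buffer (the token ids below are made up for illustration):

end_sequence = [27, 522, 9217, 29]  # hypothetical ids for "</answer>"
end_sequence_length = len(end_sequence)

def should_stop(generated: list, eos_token_id: int) -> bool:
    # Stop on the EOS token...
    if generated and generated[-1] == eos_token_id:
        return True
    # ...or once the trailing tokens equal the "</answer>" id sequence.
    if len(generated) >= end_sequence_length:
        return generated[-end_sequence_length:] == end_sequence
    return False
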
@@ -255,7 +257,6 @@ def grpo_loss(
     mx.eval(token_log_probs)
     mx.metal.clear_cache()
 
-    # Reference policy probabilities
     if ref_model is not None:
         ref_token_log_probs = get_per_token_logps(ref_model, inputs, lengths)
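
When a frozen reference model is passed in, its per-token log-probabilities are computed the same way as the policy's and later feed the KL penalty weighted by beta. As a hedged sketch, GRPO-style objectives commonly estimate that per-token KL with the k3 estimator below; the exact expression used in this file may differ:

import mlx.core as mx

def per_token_kl(token_log_probs, ref_token_log_probs):
    # k3 estimator: exp(ref - policy) - (ref - policy) - 1, elementwise per token.
    diff = ref_token_log_probs - token_log_probs
    return mx.exp(diff) - diff - 1.0
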
@@ -305,11 +306,12 @@ def grpo_loss(
     policy_ratio = mx.exp(token_log_probs - mx.stop_gradient(token_log_probs))
 
     # Compute per-token loss following GRPO formula
-    per_token_loss = -(policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask
+    per_token_loss = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask)
 
     # Average over tokens and sequences
     sequence_sums = per_token_loss.sum(axis=1)
     sequence_lengths = length_mask.sum(axis=1)
     loss = (sequence_sums / sequence_lengths).mean()
 
     # Calculate mean KL divergence for metrics
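
Taken together, these lines compute a masked, length-normalized loss: the per-token surrogate term is zeroed outside valid tokens, summed per sequence, divided by that sequence's token count, and averaged over the batch. A consolidated sketch, assuming token_log_probs, kl_div, and length_mask are (batch, seq_len) arrays and advantages is (batch,):

import mlx.core as mx

def grpo_sequence_loss(token_log_probs, advantages, kl_div, length_mask, beta):
    # Ratio of the policy to its detached copy: 1.0 in value, but it carries
    # the policy gradient, mirroring the hunk above.
    policy_ratio = mx.exp(token_log_probs - mx.stop_gradient(token_log_probs))
    per_token_loss = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask)
    sequence_sums = per_token_loss.sum(axis=1)
    sequence_lengths = length_mask.sum(axis=1)
    return (sequence_sums / sequence_lengths).mean()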