smoll fix

Goekdeniz-Guelmez 2025-02-05 11:30:21 +01:00
parent a33cad84b4
commit 35a2d99cf9


@@ -34,6 +34,7 @@ class GRPOTrainingArgs(TrainingArgs):
         }
     )
 
 def r1_extract_xml_answer(text: str) -> str:
+    """Extracts the answer from an XML formatted text string."""
     try:
@@ -41,35 +42,50 @@ def r1_extract_xml_answer(text: str) -> str:
         answer = answer.split("</answer>")[0]
         return answer.strip()
     except:
-        print("[extract_xml_answer] Failed to extract answer from: ", text)
+        print("r1_extract_xml_answer returned empty string")
         return ""

 def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
+    """Ensures we always return a list of floats."""
+    if not completions:  # Handle empty completions
+        return [0.0] * len(prompts)
     extracted_responses = [r1_extract_xml_answer(r) for r in completions]
-    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
+    return [0.5 if r and r.isdigit() else 0.0 for r in extracted_responses]

 def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
+    """Ensures we always return a list of floats."""
+    if not completions or not answer:  # Handle empty inputs
+        return [0.0] * len(prompts)
     extracted_responses = [r1_extract_xml_answer(r) for r in completions]
-    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
+    return [2.0 if r and a and r == a else 0.0 for r, a in zip(extracted_responses, answer)]

 def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
-    """Rewards completions with flexible XML format."""
+    """Ensures we always return a list of floats."""
+    if not completions:  # Handle empty completions
+        return [0.0] * len(prompts)
     pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
-    matches = [re.match(pattern, r) for r in completions]
+    matches = [bool(re.search(pattern, r)) if r else False for r in completions]
     return [0.5 if match else 0.0 for match in matches]

-def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
-    extracted_responses = [r1_extract_xml_answer(r) for r in completions]
-    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

 def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
+    """Ensures we always return a list of floats."""
+    if not completions:  # Handle empty completions
+        return [0.0] * len(prompts)
     pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>\n$"
-    matches = [re.match(pattern, r) for r in completions]
+    matches = [bool(re.search(pattern, r)) if r else False for r in completions]
     return [0.5 if match else 0.0 for match in matches]

 def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
+    """Ensures we always return a list of floats."""
+    if not completions:  # Handle empty completions
+        return [0.0] * len(prompts)
     scores = []
     for text in completions:
+        if not text:  # Handle None or empty text
+            scores.append(0.0)
+            continue
         count = 0.0
         if text.count("<think>\n") == 1:
             count += 0.125
@@ -77,11 +93,16 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
             count += 0.125
         if text.count("\n<answer>\n") == 1:
             count += 0.125
-            count -= len(text.split("\n</answer>\n")[-1])*0.001
-        if text.count("\n</answer>") == 1:
+        if text.count("\n</answer>\n") == 1:
             count += 0.125
-            count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
-        scores.append(count)
+        # Penalize extra text after </answer>
+        end_text = text.split("\n</answer>\n")[-1]
+        count -= len(end_text) * 0.001 if len(end_text) > 0 else 0
+        scores.append(max(0.0, count))  # Ensure non-negative score
     return scores
 
 def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
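
As a quick sanity check on the patched logic, the sketch below (not part of the commit) re-declares the two format patterns and the answer extractor from the diff and prints what the new rewards yield for three sample completions. Note that neither pattern sets re.DOTALL, so "." does not cross newlines: the "soft" pattern only matches when the tag contents stay on one line, while the "strict" pattern requires the exact newline layout.

import re

# Patterns copied verbatim from the diff above.
SOFT = r"<think>.*?</think>\s*<answer>.*?</answer>"
STRICT = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>\n$"

def extract(text: str) -> str:
    # Mirrors r1_extract_xml_answer. str.split never raises here, so a
    # completion with no tags returns the whole stripped string, not "".
    return text.split("<answer>")[-1].split("</answer>")[0].strip()

strict_style = "<think>\n2 + 2 = 4\n</think>\n<answer>\n4\n</answer>\n"
inline_style = "<think>2 + 2 = 4</think> <answer>4</answer>"
free_form = "The answer is 4."

for c in (strict_style, inline_style, free_form):
    ans = extract(c)
    int_reward = 0.5 if ans and ans.isdigit() else 0.0
    print(f"answer={ans!r}, int={int_reward}, "
          f"soft={bool(re.search(SOFT, c))}, strict={bool(re.search(STRICT, c))}")

# Expected:
#   strict_style -> answer='4', int=0.5, soft=False, strict=True
#   inline_style -> answer='4', int=0.5, soft=True, strict=False
#   free_form    -> answer='The answer is 4.', int=0.0, soft=False, strict=False

The strict-style sample would also earn the full 0.5 from the patched r1_count_xml: four tag counts at 0.125 each, with no trailing text after "\n</answer>\n" to penalize.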