This commit is contained in:
Goekdeniz-Guelmez 2025-03-05 12:59:41 +01:00
parent c817743333
commit 3dfb21267b
2 changed files with 432 additions and 297 deletions

View File

@ -1,6 +1,5 @@
from typing import List, Optional, Callable
import re import re
from typing import Callable, List, Optional
RewardFunctions = Callable[[List[str], List[str], List[str]], List[float]] RewardFunctions = Callable[[List[str], List[str], List[str]], List[float]]
@ -14,52 +13,73 @@ def r1_extract_xml_answer(text: str) -> str:
print("r1_extract_xml_answer returned empty string") print("r1_extract_xml_answer returned empty string")
return "" return ""
def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
def r1_int_reward_func(
prompts: list, completions: list, answer: list, **kwargs
) -> list[float]:
if not completions: if not completions:
return [0.0] * len(prompts) return [0.0] * len(prompts)
extracted_responses = [r1_extract_xml_answer(r) for r in completions] extracted_responses = [r1_extract_xml_answer(r) for r in completions]
return [0.5 if r and r.isdigit() else 0.0 for r in extracted_responses] return [0.5 if r and r.isdigit() else 0.0 for r in extracted_responses]
def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
def r1_accuracy_reward_func(
prompts: list, completions: list, answer: list, **kwargs
) -> list[float]:
if not completions or not answer: if not completions or not answer:
return [0.0] * len(prompts) return [0.0] * len(prompts)
extracted_responses = [r1_extract_xml_answer(r) for r in completions] extracted_responses = [r1_extract_xml_answer(r) for r in completions]
return [2.0 if r and a and r == a else 0.0 for r, a in zip(extracted_responses, answer)] return [
2.0 if r and a and r == a else 0.0 for r, a in zip(extracted_responses, answer)
]
def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
def r1_soft_format_reward_func(
prompts: list, completions: list, answer: list, **kwargs
) -> list[float]:
if not completions: if not completions:
return [0.0] * len(prompts) return [0.0] * len(prompts)
scores = [] scores = []
for completion in completions: for completion in completions:
if not completion: if not completion:
scores.append(0.0) scores.append(0.0)
continue continue
reason_start = completion.find("<think>") reason_start = completion.find("<think>")
reason_end = completion.find("</think>") reason_end = completion.find("</think>")
answer_start = completion.find("<answer>") answer_start = completion.find("<answer>")
answer_end = completion.find("</answer>") answer_end = completion.find("</answer>")
if (reason_start != -1 and reason_end != -1 and if (
answer_start != -1 and answer_end != -1 and reason_start != -1
reason_start < reason_end < answer_start < answer_end): and reason_end != -1
reason_content = completion[reason_start+13:reason_end].strip() and answer_start != -1
answer_content = completion[answer_start+8:answer_end].strip() and answer_end != -1
and reason_start < reason_end < answer_start < answer_end
):
reason_content = completion[reason_start + 13 : reason_end].strip()
answer_content = completion[answer_start + 8 : answer_end].strip()
if reason_content and answer_content: if reason_content and answer_content:
scores.append(0.5) scores.append(0.5)
continue continue
scores.append(0.0) scores.append(0.0)
return scores return scores
def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
def r1_strict_format_reward_func(
prompts: list, completions: list, answer: list, **kwargs
) -> list[float]:
if not completions: if not completions:
return [0.0] * len(prompts) return [0.0] * len(prompts)
pattern = r"<think> .*? </think><answer> .*? </answer>" pattern = r"<think> .*? </think><answer> .*? </answer>"
matches = [bool(re.search(pattern, r)) if r else False for r in completions] matches = [bool(re.search(pattern, r)) if r else False for r in completions]
return [0.5 if match else 0.0 for match in matches] return [0.5 if match else 0.0 for match in matches]
def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
def r1_count_xml(
prompts: list, completions: list, answer: list, **kwargs
) -> list[float]:
if not completions: if not completions:
return [0.0] * len(prompts) return [0.0] * len(prompts)
scores = [] scores = []
@ -79,4 +99,4 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> li
end_text = text.split("</answer>")[-1] end_text = text.split("</answer>")[-1]
count -= len(end_text) * 0.001 if len(end_text) > 0 else 0 count -= len(end_text) * 0.001 if len(end_text) > 0 else 0
scores.append(max(0.0, count)) scores.append(max(0.0, count))
return scores return scores

File diff suppressed because it is too large Load Diff