Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-29 04:31:13 +08:00

commit 3dfb21267b ("updates")
parent c817743333
@@ -1,6 +1,5 @@
-from typing import List, Optional, Callable
-
 import re
 
+from typing import Callable, List, Optional
 
 RewardFunctions = Callable[[List[str], List[str], List[str]], List[float]]
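The `RewardFunctions` alias fixes the shape shared by every reward below: three parallel string lists (prompts, completions, reference answers) in, one float per completion out. Note the alias cannot express the `**kwargs` that the concrete functions also accept. A minimal conforming example, purely illustrative and not part of this commit:

    from typing import Callable, List

    RewardFunctions = Callable[[List[str], List[str], List[str]], List[float]]

    def length_penalty(
        prompts: List[str], completions: List[str], answer: List[str]
    ) -> List[float]:
        # Toy reward: mildly prefer shorter completions.
        return [max(0.0, 1.0 - len(c) / 1000.0) for c in completions]

    reward_fns: List[RewardFunctions] = [length_penalty]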
@@ -14,52 +13,73 @@ def r1_extract_xml_answer(text: str) -> str:
         print("r1_extract_xml_answer returned empty string")
         return ""
 
 
-def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
+def r1_int_reward_func(
+    prompts: list, completions: list, answer: list, **kwargs
+) -> list[float]:
     if not completions:
         return [0.0] * len(prompts)
     extracted_responses = [r1_extract_xml_answer(r) for r in completions]
     return [0.5 if r and r.isdigit() else 0.0 for r in extracted_responses]
 
 
-def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
+def r1_accuracy_reward_func(
+    prompts: list, completions: list, answer: list, **kwargs
+) -> list[float]:
     if not completions or not answer:
         return [0.0] * len(prompts)
     extracted_responses = [r1_extract_xml_answer(r) for r in completions]
-    return [2.0 if r and a and r == a else 0.0 for r, a in zip(extracted_responses, answer)]
+    return [
+        2.0 if r and a and r == a else 0.0 for r, a in zip(extracted_responses, answer)
+    ]
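Both functions above grade only the text that `r1_extract_xml_answer` pulls out of the completion: `r1_int_reward_func` pays 0.5 for any all-digit answer, while `r1_accuracy_reward_func` pays 2.0 for an exact match with the reference. A quick sketch of the expected behavior, assuming the extractor returns the stripped text between `<answer>` and `</answer>` (the sample strings are made up for illustration):

    prompts = ["q1", "q2"]
    completions = [
        "<think> 2 + 2 = 4 </think><answer> 4 </answer>",
        "<think> not sure </think><answer> maybe 5 </answer>",
    ]
    answer = ["4", "5"]

    # Under the assumed extractor, the first completion yields "4" and the
    # second yields "maybe 5", so:
    r1_int_reward_func(prompts, completions, answer)       # [0.5, 0.0]
    r1_accuracy_reward_func(prompts, completions, answer)  # [2.0, 0.0]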
-def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
+def r1_soft_format_reward_func(
+    prompts: list, completions: list, answer: list, **kwargs
+) -> list[float]:
     if not completions:
         return [0.0] * len(prompts)
 
     scores = []
     for completion in completions:
         if not completion:
             scores.append(0.0)
             continue
 
         reason_start = completion.find("<think>")
         reason_end = completion.find("</think>")
         answer_start = completion.find("<answer>")
         answer_end = completion.find("</answer>")
 
-        if (reason_start != -1 and reason_end != -1 and
-            answer_start != -1 and answer_end != -1 and
-            reason_start < reason_end < answer_start < answer_end):
-            reason_content = completion[reason_start+13:reason_end].strip()
-            answer_content = completion[answer_start+8:answer_end].strip()
+        if (
+            reason_start != -1
+            and reason_end != -1
+            and answer_start != -1
+            and answer_end != -1
+            and reason_start < reason_end < answer_start < answer_end
+        ):
+            reason_content = completion[reason_start + 13 : reason_end].strip()
+            answer_content = completion[answer_start + 8 : answer_end].strip()
             if reason_content and answer_content:
                 scores.append(0.5)
                 continue
         scores.append(0.0)
     return scores
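One detail worth flagging in the hunk above: the slice offset 13 in `completion[reason_start + 13 : reason_end]` does not match `len("<think>")`, which is 7, so the first six characters of the reasoning content are silently dropped, and a short `<think>` block can strip down to an empty string and score 0.0. The hard-coded 13 looks like a holdover from a longer opening tag. A tag-length-safe variant, offered as a sketch rather than as part of this commit:

    think_tag, answer_tag = "<think>", "<answer>"
    # Skip exactly the opening tag, never into the content itself.
    reason_content = completion[reason_start + len(think_tag) : reason_end].strip()
    answer_content = completion[answer_start + len(answer_tag) : answer_end].strip()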
-def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
+def r1_strict_format_reward_func(
+    prompts: list, completions: list, answer: list, **kwargs
+) -> list[float]:
     if not completions:
         return [0.0] * len(prompts)
     pattern = r"<think> .*? </think><answer> .*? </answer>"
     matches = [bool(re.search(pattern, r)) if r else False for r in completions]
     return [0.5 if match else 0.0 for match in matches]
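Note that the strict pattern requires a literal space just inside each tag, and `.` does not match newlines unless `re.DOTALL` is passed, so multi-line reasoning fails the strict check even when the tags are well formed. For example:

    import re

    pattern = r"<think> .*? </think><answer> .*? </answer>"
    bool(re.search(pattern, "<think> add them </think><answer> 4 </answer>"))
    # True
    bool(re.search(pattern, "<think>add them</think><answer>4</answer>"))
    # False: the spaces inside the tags are required
    bool(re.search(pattern, "<think> step 1\nstep 2 </think><answer> 4 </answer>"))
    # False: "." does not cross the newline without re.DOTALL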
-def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
+def r1_count_xml(
+    prompts: list, completions: list, answer: list, **kwargs
+) -> list[float]:
     if not completions:
         return [0.0] * len(prompts)
     scores = []
@@ -79,4 +99,4 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> li
         end_text = text.split("</answer>")[-1]
         count -= len(end_text) * 0.001 if len(end_text) > 0 else 0
         scores.append(max(0.0, count))
     return scores
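The tail of `r1_count_xml`, shown in the hunk above, subtracts 0.001 per character of trailing text after the final `</answer>` and floors each score at zero. For instance:

    text = "<think> ... </think><answer> 4 </answer> trailing chatter"
    end_text = text.split("</answer>")[-1]  # " trailing chatter", 17 characters
    # count -= 17 * 0.001 == 0.017, and max(0.0, count) keeps the score non-negative.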
File diff suppressed because it is too large
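Taken together, the five functions each score a batch of completions independently. One way a training loop might fold them into a single scalar per completion, sketched here and not shown anywhere in this commit:

    def total_reward(prompts, completions, answer):
        # Sum every reward function's per-completion score (illustrative only).
        fns = [
            r1_int_reward_func,
            r1_accuracy_reward_func,
            r1_soft_format_reward_func,
            r1_strict_format_reward_func,
            r1_count_xml,
        ]
        per_fn = [fn(prompts, completions, answer) for fn in fns]
        return [sum(scores) for scores in zip(*per_fn)]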