mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-13 21:06:38 +08:00
Small fix
This commit is contained in:
parent
a33cad84b4
commit
35a2d99cf9
@ -34,6 +34,7 @@ class GRPOTrainingArgs(TrainingArgs):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def r1_extract_xml_answer(text: str) -> str:
|
def r1_extract_xml_answer(text: str) -> str:
|
||||||
"""Extracts the answer from an XML formatted text string."""
|
"""Extracts the answer from an XML formatted text string."""
|
||||||
try:
|
try:
|
||||||
@ -41,35 +42,50 @@ def r1_extract_xml_answer(text: str) -> str:
|
|||||||
answer = answer.split("</answer>")[0]
|
answer = answer.split("</answer>")[0]
|
||||||
return answer.strip()
|
return answer.strip()
|
||||||
except:
|
except:
|
||||||
print("[extract_xml_answer] Failed to extract answer from: ", text)
|
print("r1_extract_xml_answer returned empty string")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
    """Reward completions whose extracted <answer> content is an integer literal.

    Args:
        prompts: Prompt strings; used only to size the fallback result.
        completions: Model completions to score.
        answer: Reference answers (unused here; kept so all reward functions
            share one uniform signature).
        **kwargs: Ignored; accepted for signature uniformity.

    Returns:
        One float per completion: 0.5 when the extracted answer is non-empty
        and consists only of digits (``str.isdigit``), else 0.0. When
        ``completions`` is empty, returns 0.0 per prompt so callers always
        receive a list of floats.
    """
    if not completions:  # Degenerate batch: keep output length tied to prompts.
        return [0.0] * len(prompts)
    extracted_responses = [r1_extract_xml_answer(r) for r in completions]
    # `r and` guards falsy extractions ("" from a failed parse) before isdigit().
    return [0.5 if r and r.isdigit() else 0.0 for r in extracted_responses]
|
||||||
|
|
||||||
def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
    """Reward completions whose extracted <answer> exactly matches the reference.

    Args:
        prompts: Prompt strings; used only to size the fallback result.
        completions: Model completions to score.
        answer: Reference answers, aligned index-wise with ``completions``.
        **kwargs: Ignored; accepted for signature uniformity.

    Returns:
        One float per (completion, answer) pair: 2.0 when both are non-empty
        and equal, else 0.0. When either ``completions`` or ``answer`` is
        empty, returns 0.0 per prompt so callers always receive a list of
        floats. Note ``zip`` truncates to the shorter of the two lists.
    """
    if not completions or not answer:  # Degenerate batch: fall back per prompt.
        return [0.0] * len(prompts)
    extracted_responses = [r1_extract_xml_answer(r) for r in completions]
    # Both sides must be truthy before comparing, so "" never earns a reward.
    return [2.0 if r and a and r == a else 0.0 for r, a in zip(extracted_responses, answer)]
|
||||||
|
|
||||||
def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
    """Reward completions containing a flexible <think>…</think><answer>…</answer> layout.

    Args:
        prompts: Prompt strings; used only to size the fallback result.
        completions: Model completions to score; ``None``/empty entries score 0.0.
        answer: Reference answers (unused; kept for signature uniformity).
        **kwargs: Ignored; accepted for signature uniformity.

    Returns:
        One float per completion: 0.5 when the pattern is found anywhere in
        the text, else 0.0. When ``completions`` is empty, returns 0.0 per
        prompt so callers always receive a list of floats.
    """
    if not completions:  # Degenerate batch: keep output length tied to prompts.
        return [0.0] * len(prompts)
    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
    # NOTE(review): without re.DOTALL, `.` does not cross newlines, so tag
    # content spanning multiple lines never matches — confirm this is intended.
    matches = [bool(re.search(pattern, r)) if r else False for r in completions]
    return [0.5 if match else 0.0 for match in matches]
|
||||||
|
|
||||||
def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
|
|
||||||
extracted_responses = [r1_extract_xml_answer(r) for r in completions]
|
|
||||||
return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
|
|
||||||
|
|
||||||
def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
    """Reward completions matching the strict newline-delimited think/answer layout.

    The completion must be exactly ``<think>\\n…\\n</think>\\n<answer>\\n…\\n</answer>\\n``
    (anchored at both ends) to earn the reward.

    Args:
        prompts: Prompt strings; used only to size the fallback result.
        completions: Model completions to score; ``None``/empty entries score 0.0.
        answer: Reference answers (unused; kept for signature uniformity).
        **kwargs: Ignored; accepted for signature uniformity.

    Returns:
        One float per completion: 0.5 on a strict-format match, else 0.0.
        When ``completions`` is empty, returns 0.0 per prompt so callers
        always receive a list of floats.
    """
    if not completions:  # Degenerate batch: keep output length tied to prompts.
        return [0.0] * len(prompts)
    pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>\n$"
    # NOTE(review): without re.DOTALL, `.` stops at newlines, so only a single
    # content line inside each tag can match — confirm this is intended.
    matches = [bool(re.search(pattern, r)) if r else False for r in completions]
    return [0.5 if match else 0.0 for match in matches]
|
||||||
|
|
||||||
def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
|
def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
|
||||||
|
"""Ensures we always return a list of floats."""
|
||||||
|
if not completions: # Handle empty completions
|
||||||
|
return [0.0] * len(prompts)
|
||||||
|
|
||||||
scores = []
|
scores = []
|
||||||
for text in completions:
|
for text in completions:
|
||||||
|
if not text: # Handle None or empty text
|
||||||
|
scores.append(0.0)
|
||||||
|
continue
|
||||||
|
|
||||||
count = 0.0
|
count = 0.0
|
||||||
if text.count("<think>\n") == 1:
|
if text.count("<think>\n") == 1:
|
||||||
count += 0.125
|
count += 0.125
|
||||||
@ -77,11 +93,16 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> li
|
|||||||
count += 0.125
|
count += 0.125
|
||||||
if text.count("\n<answer>\n") == 1:
|
if text.count("\n<answer>\n") == 1:
|
||||||
count += 0.125
|
count += 0.125
|
||||||
count -= len(text.split("\n</answer>\n")[-1])*0.001
|
if text.count("\n</answer>\n") == 1:
|
||||||
if text.count("\n</answer>") == 1:
|
|
||||||
count += 0.125
|
count += 0.125
|
||||||
count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
|
|
||||||
scores.append(count)
|
# Penalize extra text after </answer>
|
||||||
|
end_text = text.split("\n</answer>\n")[-1]
|
||||||
|
count -= len(end_text) * 0.001 if len(end_text) > 0 else 0
|
||||||
|
|
||||||
|
scores.append(max(0.0, count)) # Ensure non-negative score
|
||||||
|
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
|
def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
|
||||||
|
Loading…
Reference in New Issue
Block a user