mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-10 19:26:46 +08:00
smoll fix
This commit is contained in:
parent
a33cad84b4
commit
35a2d99cf9
@ -34,6 +34,7 @@ class GRPOTrainingArgs(TrainingArgs):
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def r1_extract_xml_answer(text: str) -> str:
|
||||
"""Extracts the answer from an XML formatted text string."""
|
||||
try:
|
||||
@ -41,35 +42,50 @@ def r1_extract_xml_answer(text: str) -> str:
|
||||
answer = answer.split("</answer>")[0]
|
||||
return answer.strip()
|
||||
except:
|
||||
print("[extract_xml_answer] Failed to extract answer from: ", text)
|
||||
print("r1_extract_xml_answer returned empty string")
|
||||
return ""
|
||||
|
||||
def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
|
||||
"""Ensures we always return a list of floats."""
|
||||
if not completions: # Handle empty completions
|
||||
return [0.0] * len(prompts)
|
||||
extracted_responses = [r1_extract_xml_answer(r) for r in completions]
|
||||
return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
|
||||
return [0.5 if r and r.isdigit() else 0.0 for r in extracted_responses]
|
||||
|
||||
def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
|
||||
"""Ensures we always return a list of floats."""
|
||||
if not completions or not answer: # Handle empty inputs
|
||||
return [0.0] * len(prompts)
|
||||
extracted_responses = [r1_extract_xml_answer(r) for r in completions]
|
||||
return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
|
||||
return [2.0 if r and a and r == a else 0.0 for r, a in zip(extracted_responses, answer)]
|
||||
|
||||
def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
|
||||
"""Rewards completions with flexible XML format."""
|
||||
"""Ensures we always return a list of floats."""
|
||||
if not completions: # Handle empty completions
|
||||
return [0.0] * len(prompts)
|
||||
pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
|
||||
matches = [re.match(pattern, r) for r in completions]
|
||||
matches = [bool(re.search(pattern, r)) if r else False for r in completions]
|
||||
return [0.5 if match else 0.0 for match in matches]
|
||||
|
||||
def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
|
||||
extracted_responses = [r1_extract_xml_answer(r) for r in completions]
|
||||
return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
|
||||
|
||||
def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
|
||||
"""Ensures we always return a list of floats."""
|
||||
if not completions: # Handle empty completions
|
||||
return [0.0] * len(prompts)
|
||||
pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>\n$"
|
||||
matches = [re.match(pattern, r) for r in completions]
|
||||
matches = [bool(re.search(pattern, r)) if r else False for r in completions]
|
||||
return [0.5 if match else 0.0 for match in matches]
|
||||
|
||||
def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
|
||||
"""Ensures we always return a list of floats."""
|
||||
if not completions: # Handle empty completions
|
||||
return [0.0] * len(prompts)
|
||||
|
||||
scores = []
|
||||
for text in completions:
|
||||
if not text: # Handle None or empty text
|
||||
scores.append(0.0)
|
||||
continue
|
||||
|
||||
count = 0.0
|
||||
if text.count("<think>\n") == 1:
|
||||
count += 0.125
|
||||
@ -77,11 +93,16 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> li
|
||||
count += 0.125
|
||||
if text.count("\n<answer>\n") == 1:
|
||||
count += 0.125
|
||||
count -= len(text.split("\n</answer>\n")[-1])*0.001
|
||||
if text.count("\n</answer>") == 1:
|
||||
if text.count("\n</answer>\n") == 1:
|
||||
count += 0.125
|
||||
count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
|
||||
scores.append(count)
|
||||
|
||||
# Penalize extra text after </answer>
|
||||
end_text = text.split("\n</answer>\n")[-1]
|
||||
count -= len(end_text) * 0.001 if len(end_text) > 0 else 0
|
||||
|
||||
scores.append(max(0.0, count)) # Ensure non-negative score
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
|
||||
|
Loading…
Reference in New Issue
Block a user