diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index 64b0bc49..75d5207f 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -34,6 +34,7 @@ class GRPOTrainingArgs(TrainingArgs):
         }
     )
 
+
 def r1_extract_xml_answer(text: str) -> str:
     """Extracts the answer from an XML formatted text string."""
     try:
@@ -41,35 +42,50 @@ def r1_extract_xml_answer(text: str) -> str:
         answer = answer.split("</answer>")[0]
         return answer.strip()
     except:
-        print("[extract_xml_answer] Failed to extract answer from: ", text)
+        print("r1_extract_xml_answer returned empty string")
         return ""
 
 def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
+    """Ensures we always return a list of floats."""
+    if not completions:  # Handle empty completions
+        return [0.0] * len(prompts)
     extracted_responses = [r1_extract_xml_answer(r) for r in completions]
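+    # 0.5 reward when the extracted answer is a pure digit string, else 0.0.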
-    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
+    return [0.5 if r and r.isdigit() else 0.0 for r in extracted_responses]
 
 def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
+    """Ensures we always return a list of floats."""
+    if not completions or not answer:  # Handle empty inputs
+        return [0.0] * len(prompts)
     extracted_responses = [r1_extract_xml_answer(r) for r in completions]
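+    # 2.0 reward for an exact match with the reference answer, else 0.0.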
-    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
+    return [2.0 if r and a and r == a else 0.0 for r, a in zip(extracted_responses, answer)]
 
 def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
-    """Rewards completions with flexible XML format."""
+    """Ensures we always return a list of floats."""
+    if not completions:  # Handle empty completions
+        return [0.0] * len(prompts)
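+    # Loose format check: a <think> block followed by an <answer> block,
+    # matched anywhere in the completion.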
pattern = r".*?\s*.*?"
- matches = [re.match(pattern, r) for r in completions]
+ matches = [bool(re.search(pattern, r)) if r else False for r in completions]
return [0.5 if match else 0.0 for match in matches]
 
-def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
-    extracted_responses = [r1_extract_xml_answer(r) for r in completions]
-    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
-
 def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
+    """Ensures we always return a list of floats."""
+    if not completions:  # Handle empty completions
+        return [0.0] * len(prompts)
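+    # Strict format check: the completion must consist of exactly one
+    # newline-delimited <think> block and one <answer> block, nothing else.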
pattern = r"^\n.*?\n\n\n.*?\n\n$"
- matches = [re.match(pattern, r) for r in completions]
+ matches = [bool(re.search(pattern, r)) if r else False for r in completions]
return [0.5 if match else 0.0 for match in matches]
 
 def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
+    """Ensures we always return a list of floats."""
+    if not completions:  # Handle empty completions
+        return [0.0] * len(prompts)
+
     scores = []
     for text in completions:
+        if not text:  # Handle None or empty text
+            scores.append(0.0)
+            continue
+
         count = 0.0
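+        # Each correctly placed tag adds 0.125, so a fully formatted
+        # completion earns up to 0.5 before the end-text penalty below.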
         if text.count("<think>\n") == 1:
             count += 0.125
@@ -77,11 +93,16 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
             count += 0.125
         if text.count("\n<answer>\n") == 1:
             count += 0.125
-            count -= len(text.split("\n</answer>\n")[-1])*0.001
-        if text.count("\n</answer>") == 1:
+        if text.count("\n</answer>\n") == 1:
             count += 0.125
-            count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
-        scores.append(count)
+
+        # Penalize extra text after </answer>
+        end_text = text.split("\n</answer>\n")[-1]
+        count -= len(end_text) * 0.001 if len(end_text) > 0 else 0
+
+        scores.append(max(0.0, count))  # Ensure non-negative score
+
+    return scores
 
 
 def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):