diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index 64b0bc49..75d5207f 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -34,6 +34,7 @@ class GRPOTrainingArgs(TrainingArgs):
         }
     )
 
+
 def r1_extract_xml_answer(text: str) -> str:
     """Extracts the answer from an XML formatted text string."""
     try:
@@ -41,35 +42,50 @@ def r1_extract_xml_answer(text: str) -> str:
         answer = answer.split("</answer>")[0]
         return answer.strip()
     except:
-        print("[extract_xml_answer] Failed to extract answer from: ", text)
+        print("r1_extract_xml_answer returned empty string")
         return ""
 
 
 def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
+    """Ensures we always return a list of floats."""
+    if not completions:  # Handle empty completions
+        return [0.0] * len(prompts)
     extracted_responses = [r1_extract_xml_answer(r) for r in completions]
-    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
+    return [0.5 if r and r.isdigit() else 0.0 for r in extracted_responses]
 
 def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
+    """Ensures we always return a list of floats."""
+    if not completions or not answer:  # Handle empty inputs
+        return [0.0] * len(prompts)
     extracted_responses = [r1_extract_xml_answer(r) for r in completions]
-    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
+    return [2.0 if r and a and r == a else 0.0 for r, a in zip(extracted_responses, answer)]
 
 def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
-    """Rewards completions with flexible XML format."""
+    """Ensures we always return a list of floats."""
+    if not completions:  # Handle empty completions
+        return [0.0] * len(prompts)
     pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
-    matches = [re.match(pattern, r) for r in completions]
+    matches = [bool(re.search(pattern, r)) if r else False for r in completions]
     return [0.5 if match else 0.0 for match in matches]
 
-def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
-    extracted_responses = [r1_extract_xml_answer(r) for r in completions]
-    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
-
 def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
+    """Ensures we always return a list of floats."""
+    if not completions:  # Handle empty completions
+        return [0.0] * len(prompts)
     pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>\n$"
-    matches = [re.match(pattern, r) for r in completions]
+    matches = [bool(re.search(pattern, r)) if r else False for r in completions]
     return [0.5 if match else 0.0 for match in matches]
 
 def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
+    """Ensures we always return a list of floats."""
+    if not completions:  # Handle empty completions
+        return [0.0] * len(prompts)
+    scores = []
     for text in completions:
+        if not text:  # Handle None or empty text
+            scores.append(0.0)
+            continue
+
         count = 0.0
         if text.count("<think>\n") == 1:
             count += 0.125
@@ -77,11 +93,16 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
             count += 0.125
         if text.count("\n<answer>\n") == 1:
             count += 0.125
-            count -= len(text.split("\n</answer>\n")[-1])*0.001
-        if text.count("\n</answer>") == 1:
+        if text.count("\n</answer>\n") == 1:
             count += 0.125
-            count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
-        scores.append(count)
+
+        # Penalize extra text after </answer>
+        end_text = text.split("\n</answer>\n")[-1]
+        count -= len(end_text) * 0.001 if len(end_text) > 0 else 0
+
+        scores.append(max(0.0, count))  # Ensure non-negative score
+
+    return scores
 
 
 def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):