From 0a19522ec4a10460afa8fe5a048e23d3da14440c Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Wed, 5 Feb 2025 14:38:09 +0100
Subject: [PATCH] updates

---
 llms/mlx_lm/tuner/grpo_trainer.py | 22 ++++++++++++----------
 1 file changed, 12 insertions(+), 10 deletions(-)

diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index 75d5207f..a9ba4b01 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -45,44 +45,48 @@ def r1_extract_xml_answer(text: str) -> str:
         print("r1_extract_xml_answer returned empty string")
         return ""
 
+
 def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     """Ensures we always return a list of floats."""
-    if not completions: # Handle empty completions
+    if not completions:
         return [0.0] * len(prompts)
     extracted_responses = [r1_extract_xml_answer(r) for r in completions]
     return [0.5 if r and r.isdigit() else 0.0 for r in extracted_responses]
 
 
 def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     """Ensures we always return a list of floats."""
-    if not completions or not answer: # Handle empty inputs
+    if not completions or not answer:
         return [0.0] * len(prompts)
     extracted_responses = [r1_extract_xml_answer(r) for r in completions]
     return [2.0 if r and a and r == a else 0.0 for r, a in zip(extracted_responses, answer)]
 
+
 def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     """Ensures we always return a list of floats."""
-    if not completions: # Handle empty completions
+    if not completions:
         return [0.0] * len(prompts)
     pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
     matches = [bool(re.search(pattern, r)) if r else False for r in completions]
     return [0.5 if match else 0.0 for match in matches]
 
+
 def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     """Ensures we always return a list of floats."""
-    if not completions: # Handle empty completions
+    if not completions:
         return [0.0] * len(prompts)
     pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>\n$"
     matches = [bool(re.search(pattern, r)) if r else False for r in completions]
     return [0.5 if match else 0.0 for match in matches]
 
+
 def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     """Ensures we always return a list of floats."""
-    if not completions: # Handle empty completions
+    if not completions:
         return [0.0] * len(prompts)
     scores = []
     for text in completions:
-        if not text: # Handle None or empty text
+        if not text:
             scores.append(0.0)
             continue
@@ -137,11 +141,9 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
         output[current_length] = token_value
         current_length += 1
 
-        # Check for EOS token
         if token_value == tokenizer.eos_token_id:
             break
 
-        # Check for "</answer>" sequence
         if current_length >= end_sequence_length:
             last_tokens = output[current_length - end_sequence_length:current_length].tolist()
             if last_tokens == end_sequence:
@@ -255,7 +257,6 @@ def grpo_loss(
     mx.eval(token_log_probs)
     mx.metal.clear_cache()
 
-    # Reference policy probabilities
    if ref_model is not None:
         ref_token_log_probs = get_per_token_logps(ref_model, inputs, lengths)
@@ -305,11 +306,12 @@
     policy_ratio = mx.exp(token_log_probs - mx.stop_gradient(token_log_probs))
 
     # Compute per-token loss following GRPO formula
-    per_token_loss = -(policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask
+    per_token_loss = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask)
 
     # Average over tokens and sequences
     sequence_sums = per_token_loss.sum(axis=1)
     sequence_lengths = length_mask.sum(axis=1)
+
     loss = (sequence_sums / sequence_lengths).mean()
 
     # Calculate mean KL divergence for metrics
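
Note on the reparenthesized loss line in the last hunk: `-(A) * mask` and `-((A) * mask)` are algebraically identical, so that change is cosmetic; the substantive behavior is the masked per-sequence averaging. The sketch below restates the computation as a standalone function. It is a minimal illustration under stated assumptions, not code from the patch: the name `masked_grpo_loss` is hypothetical, the `kl_div` definition uses the k3 estimator (exp(d) - d - 1) common in GRPO-style trainers but not visible in the hunks shown, and the default `beta` is arbitrary.

    import mlx.core as mx

    def masked_grpo_loss(token_log_probs, ref_token_log_probs, advantages, length_mask, beta=0.04):
        # Assumed k3 KL estimator (not shown in the patch): exp(d) - d - 1 >= 0,
        # where d = ref_logp - cur_logp, computed per token.
        diff = ref_token_log_probs - token_log_probs
        kl_div = mx.exp(diff) - diff - 1

        # The ratio is exactly 1 in value but still carries the gradient
        # of token_log_probs through the non-stop-gradient term.
        policy_ratio = mx.exp(token_log_probs - mx.stop_gradient(token_log_probs))

        # Per-token GRPO objective, masked to real (non-padding) tokens.
        per_token_loss = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask)

        # Average each sequence over its own unmasked length, then over the batch.
        sequence_sums = per_token_loss.sum(axis=1)
        sequence_lengths = length_mask.sum(axis=1)
        return (sequence_sums / sequence_lengths).mean()

Dividing by `length_mask.sum(axis=1)` rather than the padded width keeps short completions from being down-weighted relative to long ones.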