diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index 75d5207f..a9ba4b01 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -45,44 +45,48 @@ def r1_extract_xml_answer(text: str) -> str:
print("r1_extract_xml_answer returned empty string")
return ""
+
def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
"""Ensures we always return a list of floats."""
- if not completions: # Handle empty completions
+ if not completions:
return [0.0] * len(prompts)
extracted_responses = [r1_extract_xml_answer(r) for r in completions]
return [0.5 if r and r.isdigit() else 0.0 for r in extracted_responses]
def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
"""Ensures we always return a list of floats."""
- if not completions or not answer: # Handle empty inputs
+ if not completions or not answer:
return [0.0] * len(prompts)
extracted_responses = [r1_extract_xml_answer(r) for r in completions]
return [2.0 if r and a and r == a else 0.0 for r, a in zip(extracted_responses, answer)]
+
def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
"""Ensures we always return a list of floats."""
- if not completions: # Handle empty completions
+ if not completions:
return [0.0] * len(prompts)
pattern = r".*?\s*.*?"
matches = [bool(re.search(pattern, r)) if r else False for r in completions]
return [0.5 if match else 0.0 for match in matches]
+
def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
"""Ensures we always return a list of floats."""
- if not completions: # Handle empty completions
+ if not completions:
return [0.0] * len(prompts)
pattern = r"^\n.*?\n\n\n.*?\n\n$"
matches = [bool(re.search(pattern, r)) if r else False for r in completions]
return [0.5 if match else 0.0 for match in matches]
+
def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
"""Ensures we always return a list of floats."""
- if not completions: # Handle empty completions
+ if not completions:
return [0.0] * len(prompts)
scores = []
for text in completions:
- if not text: # Handle None or empty text
+ if not text:
scores.append(0.0)
continue
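
Note: all five reward helpers in this hunk share one contract: on empty input they return `[0.0] * len(prompts)`, otherwise one float per completion. A minimal sketch of that contract (hypothetical `reward_template` name, toy scores):

```python
# Hypothetical template illustrating the shared reward-function contract;
# the real scoring logic lives in the r1_* helpers above.
def reward_template(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
    if not completions:  # degenerate input -> one 0.0 per prompt
        return [0.0] * len(prompts)
    return [0.5 if c else 0.0 for c in completions]

print(reward_template(["p1", "p2"], [], []))                   # [0.0, 0.0]
print(reward_template(["p1"], ["<answer>3</answer>"], ["3"]))  # [0.5]
```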
@@ -137,11 +141,9 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
output[current_length] = token_value
current_length += 1
- # Check for EOS token
if token_value == tokenizer.eos_token_id:
break
- # Check for "" sequence
if current_length >= end_sequence_length:
last_tokens = output[current_length - end_sequence_length:current_length].tolist()
if last_tokens == end_sequence:
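
For reference, the stop condition this hunk touches compares the last `end_sequence_length` generated tokens against the encoded `</answer>` tag. A standalone sketch of that check, with made-up token ids and a hypothetical helper name:

```python
# Sketch of the tail comparison in generate_grpo; `end_sequence` stands in
# for tokenizer.encode("</answer>") with toy ids.
def hit_end_sequence(output: list[int], current_length: int, end_sequence: list[int]) -> bool:
    n = len(end_sequence)
    if current_length < n:
        return False
    return output[current_length - n:current_length] == end_sequence

end_sequence = [3, 42]          # toy ids for "</answer>"
generated = [12, 7, 99, 3, 42]
assert hit_end_sequence(generated, len(generated), end_sequence)
```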
@@ -255,7 +257,6 @@ def grpo_loss(
mx.eval(token_log_probs)
mx.metal.clear_cache()
-
# Reference policy probabilities
if ref_model is not None:
ref_token_log_probs = get_per_token_logps(ref_model, inputs, lengths)
@@ -305,11 +306,12 @@ def grpo_loss(
policy_ratio = mx.exp(token_log_probs - mx.stop_gradient(token_log_probs))
# Compute per-token loss following GRPO formula
- per_token_loss = -(policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask
+ per_token_loss = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask)
# Average over tokens and sequences
sequence_sums = per_token_loss.sum(axis=1)
sequence_lengths = length_mask.sum(axis=1)
+
loss = (sequence_sums / sequence_lengths).mean()
# Calculate mean KL divergence for metrics
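
The parenthesization change above is numerically a no-op but makes the negation apply to the whole masked term; the masked averaging then divides each sequence's summed loss by its true token count. A numpy sketch of the same computation, with illustrative shapes and values rather than the trainer's tensors:

```python
import numpy as np

# Toy batch: 2 sequences, 3 token positions; values are illustrative only.
token_log_probs = np.log([[0.5, 0.4, 0.1],
                          [0.3, 0.6, 0.2]])
advantages = np.array([1.0, -0.5])        # one scalar advantage per sequence
kl_div = np.full((2, 3), 0.01)            # per-token KL to the reference policy
length_mask = np.array([[1.0, 1.0, 0.0],  # last position of seq 0 is padding
                        [1.0, 1.0, 1.0]])
beta = 0.1

# exp(logp - stop_gradient(logp)) evaluates to 1.0 in the forward pass; it
# only shapes the backward pass, so plain ones stand in for it here.
policy_ratio = np.ones_like(token_log_probs)

per_token_loss = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div)
                   * length_mask)
loss = (per_token_loss.sum(axis=1) / length_mask.sum(axis=1)).mean()
print(loss)
```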