diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index b948ae01..68fa93da 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -206,15 +206,15 @@ def build_parser():
    )
    parser.add_argument(
        "--use-chat-template",
- type=bool,
+ action="store_true",
help="If the model is a Chat model, use the Chat template.",
- default=False,
+ default=None,
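+    # None (not False) so a value from a YAML config can still apply when the flag is absent on the CLI.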
    )
    parser.add_argument(
        "--use-prompt",
- type=bool,
- help="Rather to use the prompt from teh R1 paper.",
- default=False,
+ action="store_true",
+ help="Rather to use the prompt from the R1 paper.",
+ default=None,
)
    return parser
diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index 0210b44a..64b0bc49 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -12,7 +12,6 @@ from mlx.utils import tree_flatten
from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients, iterate_batches
-
@dataclass
class GRPOTrainingArgs(TrainingArgs):
    group_size: int = field(
@@ -35,7 +34,6 @@ class GRPOTrainingArgs(TrainingArgs):
        }
    )
-
def r1_extract_xml_answer(text: str) -> str:
"""Extracts the answer from an XML formatted text string."""
try:
@@ -46,62 +44,30 @@ def r1_extract_xml_answer(text: str) -> str:
print("[extract_xml_answer] Failed to extract answer from: ", text)
return ""
-def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
- """Calculates reward based on accuracy of extracted answers.
- Args:
- prompts: List of input prompts
- completions: List of completion strings
- answer: Expected answer or list of answers
- **kwargs: Additional arguments
- Returns:
- list[float]: Reward values for each completion
- """
- extracted_responses = [r1_extract_xml_answer(r) for r in completions]
- return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
-
def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
- """Rewards numerical responses.
- Args:
- prompts: List of input prompts
- completions: List of completion strings
- answer: Expected answer or list of answers
- **kwargs: Additional arguments
- Returns:
- list[float]: Reward values for each completion
- """
    extracted_responses = [r1_extract_xml_answer(r) for r in completions]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
+def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
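+    # Reward 2.0 when the extracted <answer> text exactly matches the reference, else 0.0.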
+ extracted_responses = [r1_extract_xml_answer(r) for r in completions]
+ return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
+
def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
"""Rewards completions with flexible XML format."""
pattern = r".*?\s*.*?"
matches = [re.match(pattern, r) for r in completions]
return [0.5 if match else 0.0 for match in matches]
def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
- """Rewards completions with strict XML format.
- Args:
- prompts: List of input prompts
- completions: List of completion strings
- answer: Expected answer or list of answers
- **kwargs: Additional arguments
- Returns:
- list[float]: Reward values for each completion
- """
pattern = r"^\n.*?\n\n\n.*?\n\n$"
matches = [re.match(pattern, r) for r in completions]
return [0.5 if match else 0.0 for match in matches]
def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
- """Calculates score based on XML formatting.
- Args:
- prompts: List of input prompts (unused)
- completions: List of completion strings to evaluate
- answer: Expected answer or list of answers (unused)
- **kwargs: Additional arguments
- Returns:
- list[float]: List of scores based on XML tag presence and formatting
- """
    scores = []
    for text in completions:
        count = 0.0
@@ -116,10 +82,9 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
            count += 0.125
            count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
        scores.append(count)
    return scores
-
-def generate_grpo(model, prompt, max_tokens, tokenizer, temperature=1.0):
+def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
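+    # Normalize the prompt to shape [batch, seq_len] before sampling.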
    if len(prompt.shape) == 1:
        prompt = prompt[None, :]
    if prompt.shape[1] == 0:
@@ -172,30 +137,24 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature=1.0):
def get_per_token_logps(model, inputs, lengths):
- logits = model(inputs).astype(mx.float16) # [batch_size, seq_len, vocab_size]
- # Remove last position as it corresponds to the next token prediction
- logits = logits[:, :-1, :] # [batch_size, seq_len-1, vocab_size]
- targets = inputs[:, 1:] # Shift inputs to get targets
+ logits = model(inputs).astype(mx.float16)
+ logits = logits[:, :-1, :]
+ targets = inputs[:, 1:]
- # Process sequences individually to save memory
    per_token_logps = []
    for i in range(logits.shape[0]):
- # Get sequence length for this example
- seq_len = int(lengths[i]) - 1 # -1 because we removed last position
+ seq_len = int(lengths[i]) - 1
- # Get logits and targets for this sequence
- seq_logits = logits[i, :seq_len] # [seq_len, vocab_size]
- seq_targets = targets[i, :seq_len] # [seq_len]
+ seq_logits = logits[i, :seq_len]
+ seq_targets = targets[i, :seq_len]
- # Compute log probabilities
- log_probs = nn.log_softmax(seq_logits, axis=-1) # [seq_len, vocab_size]
+ log_probs = nn.log_softmax(seq_logits, axis=-1)
- # Gather log probs for actual tokens
        token_log_probs = mx.take_along_axis(
            log_probs,
            seq_targets.reshape(seq_len, 1),
            axis=-1
- ).squeeze(-1) # [seq_len]
+ ).squeeze(-1)
        per_token_logps.append(token_log_probs)
    mx.eval(logits)
@@ -316,7 +275,7 @@ def grpo_loss(
    advantages = (rewards - mean_rewards) / (std_rewards + epsilon)
    # Compute KL divergence using Schulman's approximator
- kl_div = mx.exp(ref_token_log_probs - token_log_probs) - (ref_token_log_probs - token_log_probs) - 1
+ kl_div = (mx.exp(token_log_probs - ref_token_log_probs) - 1) - (token_log_probs - ref_token_log_probs)
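+    # k3-style estimator: with d = token_log_probs - ref_token_log_probs, this equals exp(d) - d - 1 >= 0.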
    # Create mask for valid tokens
    length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)
@@ -325,10 +284,10 @@ def grpo_loss(
    policy_ratio = mx.exp(token_log_probs - mx.stop_gradient(token_log_probs))
    # Compute per-token loss following GRPO formula
- per_token_loss = -(policy_ratio * advantages.reshape(-1, 1) - beta * kl_div)
+ per_token_loss = -(policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask
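+    # length_mask zeroes padded positions so they drop out of the per-sequence sums below.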
    # Average over tokens and sequences
- sequence_sums = (per_token_loss * length_mask).sum(axis=1)
+ sequence_sums = per_token_loss.sum(axis=1)
    sequence_lengths = length_mask.sum(axis=1)
    loss = (sequence_sums / sequence_lengths).mean()