diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index 954eb81c..9d938df8 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -13,6 +13,7 @@ import numpy as np
from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients
from ..utils import generate_step
+from ..models import cache
@dataclass
class GRPOTrainingArgs(TrainingArgs):
@@ -96,15 +97,13 @@ def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, *
if reason_content and answer_content:
scores.append(0.5)
continue
-
scores.append(0.0)
-
return scores
def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
if not completions:
return [0.0] * len(prompts)
- pattern = r"^\n.*?\n\n\n.*?\n\n$"
+ pattern = r"\n.*?\n\n*?"
matches = [bool(re.search(pattern, r)) if r else False for r in completions]
return [0.5 if match else 0.0 for match in matches]
@@ -120,28 +119,31 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> li
count = 0.0
if text.count("\n") == 1:
count += 0.125
- if text.count("\n\n") == 1:
+ if text.count("") == 1:
count += 0.125
- if text.count("\n\n") == 1:
+ if text.count("") == 1:
count += 0.125
- if text.count("\n\n") == 1:
+ if text.count("") == 1:
count += 0.125
- end_text = text.split("\n\n")[-1]
+ end_text = text.split("")[-1]
count -= len(end_text) * 0.001 if len(end_text) > 0 else 0
scores.append(max(0.0, count))
return scores
-def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size, is_training=False):
+def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size, is_training=False, end_token: str = ""):
+ if model.training == False:
+ print("Model is in training mode", model.training, "Manually setting to eval mode")
+ model.train()
+
if len(prompts.shape) == 1:
prompts = prompts[None, :]
if prompts.shape[1] == 0:
return None
- model.eval()
batch_size = prompts.shape[0] * group_size
expanded_prompts = mx.repeat(prompts, group_size, axis=0)
- end_sequence = mx.array(tokenizer.encode(""))
+ end_sequence = mx.array(tokenizer.encode(end_token))
results = []
tokens_generated = 0
@@ -153,13 +155,15 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
if is_training:
current_input = expanded_prompts[idx]
+ prompt_cache = cache.make_prompt_cache(model)
+
+ # Initial forward pass with the prompt
+ logits = model(current_input[None], cache=prompt_cache)[:, -1]
+
while len(current_tokens) < max_tokens:
- logits = model(current_input[None])[:, -1]
probs = nn.softmax(logits, axis=-1)
next_token = mx.argmax(probs, axis=-1)
token = next_token.item()
- current_tokens.append(token)
- tokens_generated += 1
if token == tokenizer.eos_token_id:
break
if (len(current_tokens) >= len(end_sequence) and
@@ -168,10 +172,13 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
end_sequence
)):
break
- current_input = mx.concatenate([current_input, mx.array([token])])
- if len(current_tokens) % 32 == 0:
- mx.eval(current_input)
- mx.metal.clear_cache()
+
+ current_tokens.append(token)
+ tokens_generated += 1
+ current_input = mx.array([token])
+ logits = model(current_input[None], cache=prompt_cache)[:, -1]
+ mx.eval(current_input)
+ mx.metal.clear_cache()
else:
generator = generate_step(
expanded_prompts[idx],
@@ -180,10 +187,10 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
sampler=lambda x: mx.argmax(x, axis=-1)
)
for token, _ in generator:
- current_tokens.append(token)
- tokens_generated += 1
if token == tokenizer.eos_token_id:
break
+ current_tokens.append(token)
+ tokens_generated += 1
if current_tokens:
results.append(mx.array(current_tokens))
@@ -401,7 +408,7 @@ def grpo_loss(
if is_validation:
print(f"\nValidation sample generation:\n{all_completion_texts[-1]}\n")
-
+ print(f"Validation sample answer:\n{answer_text[-1]}\n")
mx.metal.clear_cache()
return loss, sequence_lengths.sum(), metrics