diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index 3e581d13..12553b8a 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -76,10 +76,30 @@ def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kw
 def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     if not completions:
         return [0.0] * len(prompts)
-    has_think = r"<think>.*</think>"
-    has_answer = r"<answer>.*</answer>"
-    matches = [(bool(re.search(has_think, r)) and bool(re.search(has_answer, r))) if r else False for r in completions]
-    return [0.5 if match else 0.0 for match in matches]
+
+    scores = []
+    for completion in completions:
+        if not completion:
+            scores.append(0.0)
+            continue
+
+        reason_start = completion.find("<think>")
+        reason_end = completion.find("</think>")
+        answer_start = completion.find("<answer>")
+        answer_end = completion.find("</answer>")
+
+        if (reason_start != -1 and reason_end != -1 and
+            answer_start != -1 and answer_end != -1 and
+            reason_start < reason_end < answer_start < answer_end):
+            reason_content = completion[reason_start + len("<think>"):reason_end].strip()
+            answer_content = completion[answer_start + len("<answer>"):answer_end].strip()
+            if reason_content and answer_content:
+                scores.append(0.5)
+                continue
+
+        scores.append(0.0)
+
+    return scores

 def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     if not completions:
@@ -110,7 +130,7 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> li
         count -= len(end_text) * 0.001 if len(end_text) > 0 else 0
         scores.append(max(0.0, count))
     return scores
-
+

 def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size, is_training=False):
     if len(prompts.shape) == 1:
@@ -118,53 +138,49 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
     if prompts.shape[1] == 0:
         return None
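+    # switch the model to eval mode for generation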
+    model.eval()
     batch_size = prompts.shape[0] * group_size
     expanded_prompts = mx.repeat(prompts, group_size, axis=0)
-    mx.eval(expanded_prompts)
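+    # token ids that mark the end of a completion; generation can stop once they appear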
+    end_sequence = mx.array(tokenizer.encode("</answer>"))
     results = []
     tokens_generated = 0
     start_time = time.perf_counter()
-    for idx in range(batch_size):
-        current_prompt = expanded_prompts[idx:idx+1]
-        mx.eval(current_prompt)
-
-        current_tokens = []
-        try:
+    try:
+        for idx in range(batch_size):
+            current_tokens = []
+
             if is_training:
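+                # training path: decode one token at a time with greedy (argmax) selection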
-                # Initialize with prompt
-                current_input = current_prompt[0]
-                mx.eval(current_input)
-
+                current_input = expanded_prompts[idx]
                 while len(current_tokens) < max_tokens:
-                    # Generate one token at a time
-                    logits = model(current_input[None])
-                    next_token = mx.random.categorical(logits[:, -1, :])
+                    logits = model(current_input[None])[:, -1]
+                    next_token = mx.argmax(logits, axis=-1)
                     token = next_token.item()
                     current_tokens.append(token)
                     tokens_generated += 1
-                    # Clear intermediate results
-                    mx.eval(next_token)
-                    del logits
-
                     if token == tokenizer.eos_token_id:
                         break
-                    # Update input for next iteration
-                    current_input = mx.array([token])
-                    mx.eval(current_input)
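+                    # stop early once the end sequence has been generated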
+                    if (len(current_tokens) >= len(end_sequence) and
+                        mx.array_equal(
+                            mx.array(current_tokens[-len(end_sequence):]),
+                            end_sequence
+                        )):
+                        break
-                    # Clear cache periodically
-                    if len(current_tokens) % 8 == 0:
+                    current_input = mx.concatenate([current_input, mx.array([token])])
+
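+                    # every 32 tokens, force evaluation and clear the Metal cache to bound memory use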
+                    if len(current_tokens) % 32 == 0:
+                        mx.eval(current_input)
                         mx.metal.clear_cache()
             else:
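+                # otherwise decode the completion with generate_step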
                 generator = generate_step(
-                    current_prompt[0],
+                    expanded_prompts[idx],
                     model,
                     max_tokens=max_tokens,
-                    sampler=lambda x: mx.random.categorical(x)
+                    sampler=lambda x: mx.argmax(x, axis=-1)
                 )
                 for token, _ in generator:
@@ -174,28 +190,18 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
                         break
             if current_tokens:
-                token_array = mx.array(current_tokens)
-                mx.eval(token_array)
-                results.append(token_array)
-                del token_array
-
-        except Exception as e:
-            print(f"Generation failed for sequence {idx}: {e}")
-            continue
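+                # keep the generated ids for this sequence and free cached buffers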
+                results.append(mx.array(current_tokens))
+                mx.metal.clear_cache()
-    mx.metal.clear_cache()
+        mx.eval(results)
+        generation_time = time.perf_counter() - start_time
+        print(f"Generated {tokens_generated} tokens in {generation_time:.2f}s ({tokens_generated/generation_time:.2f} tokens/s)")
+        return results
-    if not results:
-        print("No successful generations")
+    except Exception as e:
+        print(f"Generation error: {str(e)}")
         return None
-    mx.eval(results)
-
-    generation_time = time.perf_counter() - start_time
-    print(f"Generated {tokens_generated} tokens in {generation_time:.2f}s ({tokens_generated/generation_time:.2f} tokens/s)")
-
-    return results
-

 def get_per_token_logps(model: nn.Module, inputs, lengths):
     logits = model(inputs).astype(mx.float16)
@@ -243,15 +249,23 @@ def grpo_loss(
         prompt_tensor = mx.stack([mx.array(p) for p in batch_prompts])
         try:
-            completions = generate_grpo(
-                model,
-                prompt_tensor,
-                max_tokens,
-                tokenizer,
-                group_size,
-                True
-            )
-
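+            # validation decodes with the default path; training enables the token-by-token loop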
+            if is_validation:
+                completions = generate_grpo(
+                    model,
+                    prompt_tensor,
+                    max_tokens,
+                    tokenizer,
+                    group_size
+                )
+            else:
+                completions = generate_grpo(
+                    model,
+                    prompt_tensor,
+                    max_tokens,
+                    tokenizer,
+                    group_size,
+                    is_training=True
+                )
             if completions is not None:
                 for completion_ids in completions:
                     completion_text = tokenizer.decode(completion_ids.tolist())
@@ -261,8 +275,6 @@ def grpo_loss(
         except Exception as e:
             print(f"Generation error: {e}")
             continue
-
-        mx.metal.clear_cache()

     expanded_answers = []
     expanded_prompts = []