diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index 4d79b1ac..edf4cf69 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -374,7 +374,9 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set
max_seq_length=args.max_seq_length,
beta=args.beta,
group_size=args.group_size,
- epsilon=args.epsilon
+ epsilon=args.epsilon,
+ temperature=args.temperature,
+ max_tokens=args.max_seq_length
)
test_ppl = math.exp(test_loss)
diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index ea59ed06..f215c0ed 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -74,15 +74,14 @@ def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kw
extracted_responses = [r1_extract_xml_answer(r) for r in completions]
return [2.0 if r and a and r == a else 0.0 for r, a in zip(extracted_responses, answer)]
-
def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
if not completions:
return [0.0] * len(prompts)
-    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
-    matches = [bool(re.search(pattern, r)) if r else False for r in completions]
+    has_think = r"<think>.*</think>"
+    has_answer = r"<answer>.*</answer>"
+    matches = [(bool(re.search(has_think, r)) and bool(re.search(has_answer, r))) if r else False for r in completions]
return [0.5 if match else 0.0 for match in matches]
-
def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
if not completions:
return [0.0] * len(prompts)
@@ -114,44 +113,95 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> li
return scores
-def generate_grpo(model: nn.Module, prompt, max_tokens, tokenizer, temperature):
- if len(prompt.shape) == 1:
- prompt = prompt[None, :]
- if prompt.shape[1] == 0:
+def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, temperature, group_size):
+ if len(prompts.shape) == 1:
+ prompts = prompts[None, :]
+ if prompts.shape[1] == 0:
return None
-    end_sequence = tokenizer.encode("</answer>")
- end_sequence_length = len(end_sequence)
- initial_length = prompt.shape[1]
- output = mx.zeros((initial_length + max_tokens,), dtype=mx.int32)
- output[:initial_length] = prompt[0]
- current_length = initial_length
+
+ start_time = time.perf_counter()
+ tokens_generated = 0
+ batch_size = prompts.shape[0] * group_size
+
+ # Repeat each prompt group_size times
+ expanded_prompts = mx.repeat(prompts, group_size, axis=0)
+
+    end_sequence = mx.array(tokenizer.encode("</answer>"))
+ end_len = len(end_sequence)
+ initial_length = prompts.shape[1]
+
+    # Initialize output tensor for all sequences: prompts followed by room for generated tokens
+    output = mx.concatenate([expanded_prompts, mx.zeros((batch_size, max_tokens), dtype=mx.int32)], axis=1)
+ current_lengths = mx.array([initial_length] * batch_size)
+
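+    # Reciprocal temperature used to scale logits before sampling (no scaling when temperature is 0)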
+    temp_factor = 1/temperature if temperature > 0 else 1.0
+
try:
- def sample(logits):
- if temperature > 0:
- logits /= temperature
- logprobs = logits - mx.logsumexp(logits, keepdims=True)
- return mx.random.categorical(logprobs[None, :]).astype(mx.int32)[0]
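+        # One flag per sequence: True while that sequence is still generating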
+ not_finished = mx.ones((batch_size,), dtype=mx.bool_)
+
for _ in range(max_tokens):
- current_input = output[:current_length][None, :]
- logits = model(current_input)
- token_logits = logits[0, -1]
- next_token = sample(token_logits)
- token_value = next_token.item()
- output[current_length] = token_value
- current_length += 1
- if token_value == tokenizer.eos_token_id:
+ # Check if all sequences are finished
+ if not mx.sum(not_finished).item():
break
- if current_length >= end_sequence_length:
- last_tokens = output[current_length - end_sequence_length:current_length].tolist()
- if last_tokens == end_sequence:
- break
- if current_length > initial_length:
- return output[:current_length]
+
+            # Forward pass over the full current prefixes (no KV cache); keep only the last-position logits
+ max_len = mx.max(current_lengths).item()
+ batch_inputs = output[:, :max_len]
+ logits = model(batch_inputs)[:, -1]
+
+            # Zero out logits for sequences that are already finished; their sampled tokens are discarded below
+ logits = logits * mx.expand_dims(not_finished, -1)
+
+ # Sample next tokens
+ logits *= temp_factor
+ logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
+ next_tokens = mx.random.categorical(logprobs)
+
+ # Update outputs for active sequences
+ for idx in range(batch_size):
+ if not_finished[idx].item():
+ curr_len = current_lengths[idx].item()
+ token_value = next_tokens[idx].item()
+
+ # Create new arrays with updates
+ output = mx.array(output.tolist()) # Make a copy
+ output[idx, curr_len] = token_value
+ current_lengths = mx.array([
+ l + 1 if i == idx else l
+ for i, l in enumerate(current_lengths.tolist())
+ ])
+ tokens_generated += 1
+
+ # Check end conditions
+ if token_value == tokenizer.eos_token_id:
+ not_finished = mx.array([
+ False if i == idx else nf
+ for i, nf in enumerate(not_finished.tolist())
+ ])
+ elif curr_len >= end_len:
+ last_tokens = output[idx, curr_len-end_len+1:curr_len+1]
+ if mx.array_equal(last_tokens, end_sequence):
+ not_finished = mx.array([
+ False if i == idx else nf
+ for i, nf in enumerate(not_finished.tolist())
+ ])
+
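+            # Periodically force evaluation so the lazy compute graph does not grow unbounded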
+ if _ % 32 == 0:
+ mx.eval(output, current_lengths, not_finished)
+
+ end_time = time.perf_counter()
+ generation_time = end_time - start_time
+ tokens_per_second = tokens_generated / generation_time
+ print(f"Generated {tokens_generated} tokens in {generation_time:.2f}s ({tokens_per_second:.2f} tokens/s)")
+
+ # Return only the valid part of each sequence
+ results = [output[i, :current_lengths[i].item()] for i in range(batch_size)]
+ return results
+
except Exception as e:
print(f"Generation error: {str(e)}")
return None
-
- return None
def get_per_token_logps(model: nn.Module, inputs, lengths):
@@ -185,7 +235,8 @@ def grpo_loss(
epsilon=1e-4,
max_tokens=64,
temperature=1.0,
- reward_weights=None
+ reward_weights=None,
+ is_validation=False
):
prompt_tokens, _, prompt_text, answer_text = batch
batch_size = len(prompt_tokens)
@@ -195,22 +246,27 @@ def grpo_loss(
for i in range(0, batch_size, batch_size):
batch_prompts = prompt_tokens[i:i+batch_size]
- for prompt in batch_prompts:
- prompt_tensor = mx.array(prompt)
- for _ in range(group_size):
- try:
- completion_ids = generate_grpo(model, prompt_tensor, max_tokens, tokenizer, temperature)
- if completion_ids is not None:
- completion_text = tokenizer.decode(completion_ids.tolist())
- all_completions.append(completion_ids)
- all_completion_texts.append(completion_text)
- mx.eval(completion_ids)
- del completion_ids
- except Exception as e:
- print(f"Generation error: {e}")
- continue
-
- mx.metal.clear_cache()
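+        # Stack prompts into a single batch; assumes all prompts in the batch share the same length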
+ prompt_tensor = mx.stack([mx.array(p) for p in batch_prompts])
+
+ try:
+ completions = generate_grpo(
+ model,
+ prompt_tensor,
+ max_tokens,
+ tokenizer,
+ temperature,
+ group_size
+ )
+
+ if completions is not None:
+ for completion_ids in completions:
+ completion_text = tokenizer.decode(completion_ids.tolist())
+ all_completions.append(completion_ids)
+ all_completion_texts.append(completion_text)
+ mx.eval(completion_ids)
+ except Exception as e:
+ print(f"Generation error: {e}")
+ continue
expanded_answers = []
expanded_prompts = []
@@ -242,15 +298,12 @@ def grpo_loss(
token_log_probs = get_per_token_logps(model, inputs, lengths)
mx.eval(token_log_probs)
- mx.metal.clear_cache()
-
- # Reference policy probabilities
+
if ref_model is None:
ref_token_log_probs = token_log_probs
else:
ref_token_log_probs = get_per_token_logps(ref_model, inputs, lengths)
mx.eval(ref_token_log_probs)
- mx.metal.clear_cache()
max_len = max(x.shape[0] for x in token_log_probs)
padded_log_probs = []
@@ -339,6 +392,10 @@ def grpo_loss(
'kl': mean_kl,
**reward_metrics
}
+
+ if is_validation:
+ print(f"\nValidation sample generation:\n{all_completion_texts[-1]}\n")
+
mx.metal.clear_cache()
return loss, sequence_lengths.sum(), metrics
@@ -412,7 +469,7 @@ def evaluate_grpo(
):
all_losses = 0
ntokens = 0
- all_metrics = None
+ all_metrics = None
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
@@ -434,7 +491,8 @@ def evaluate_grpo(
epsilon=epsilon,
ref_model=ref_model,
temperature=temperature,
- max_tokens=max_tokens
+ max_tokens=max_tokens,
+ is_validation=True
)
all_losses += losses * toks