diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py
index 0c19031b..163b1be7 100644
--- a/llms/mlx_lm/tuner/datasets.py
+++ b/llms/mlx_lm/tuner/datasets.py
@@ -9,7 +9,7 @@ class GRPODataset:
"""
Dataset wrapper for GRPO training data.
Each example should have a 'prompt' and 'answer' field.
- Returns data in (prompt, answer) tuple format required by GRPO trainer.
+ Returns data in (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple format.
"""
def __init__(
self,
@@ -20,15 +20,14 @@ class GRPODataset:
):
self._data = []
for item in data:
- # Get prompt and answer text
- prompt = str(item[prompt_key])
- answer = str(item[answer_key])
-
- # Store as (prompt, answer) tuple
- self._data.append((prompt, answer))
+ prompt_str = str(item[prompt_key])
+ answer_str = str(item[answer_key])
+ prompt_tokens = tokenizer.encode(prompt_str)
+ answer_tokens = tokenizer.encode(answer_str)
+ self._data.append((prompt_tokens, answer_tokens, prompt_str, answer_str))
- def __getitem__(self, idx: int) -> Tuple[str, str]:
- """Returns a (prompt, answer) tuple for the given index."""
+ def __getitem__(self, idx: int) -> Tuple[List[int], List[int], str, str]:
+ """Returns a (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple."""
return self._data[idx]
def __len__(self) -> int:
diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index e997b504..16125611 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -12,7 +12,7 @@ from mlx.utils import tree_flatten
from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients, iterate_batches
-from mlx_lm import generate
+from mlx_lm.utils import generate_step
@dataclass
@@ -35,6 +35,70 @@ class GRPOTrainingArgs(TrainingArgs):
)
+def generate_for_grpo(
+ model,
+ prompt,
+ max_tokens,
+ tokenizer,
+ temperature=1.0
+):
+    """Sample up to ``max_tokens`` tokens from ``prompt`` autoregressively.
+
+    Returns the prompt tokens concatenated with the generated tokens, stopping
+    early when the tokenizer's EOS token is produced.
+    """
+    # Ensure the prompt has a batch dimension
+    if len(prompt.shape) == 1:
+        prompt = prompt[None, :]
+
+    generated = []
+    current_prompt = prompt[0]
+
+    for _ in range(max_tokens):
+        # Forward pass over the current sequence; keep only the last position's logits
+        logits = model(current_prompt[None, :])
+        token_logits = logits[0, -1]
+
+        if temperature > 0:
+            # mx.random.categorical samples from unnormalized logits
+            next_token = mx.random.categorical((token_logits / temperature)[None, :])[0]
+        else:
+            # Greedy decoding when temperature is 0
+            next_token = mx.argmax(token_logits)
+
+        # Force evaluation so generation errors surface immediately
+        mx.eval(next_token)
+        token_value = next_token.item()
+
+        generated.append(next_token)
+        current_prompt = mx.concatenate([current_prompt, next_token[None]])
+
+        if token_value == tokenizer.eos_token_id:
+            break
+
+    if not generated:
+        return prompt[0]
+
+    result = mx.concatenate([prompt[0], mx.stack(generated)])
+    mx.eval(result)
+    return result
+
+
def r1_extract_xml_answer(text: str) -> str:
"""Extracts the answer from an XML formatted text string."""
try:
@@ -45,42 +109,45 @@ def r1_extract_xml_answer(text: str) -> str:
print("[extract_xml_answer] Failed to extract answer from: ", text)
return ""
-def r1_accuracy_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
+def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
"""Calculates reward based on accuracy of extracted answers.
-
Args:
prompts: List of input prompts
completions: List of completion strings
answer: Expected answer or list of answers
**kwargs: Additional arguments
-
Returns:
list[float]: Reward values for each completion
"""
extracted_responses = [r1_extract_xml_answer(r) for r in completions]
- q = prompts[0] if isinstance(prompts[0], str) else prompts[0][-1]['content']
return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
-def r1_int_reward_func(completions, **kwargs) -> list[float]:
+def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
"""Rewards numerical responses.
-
Args:
+ prompts: List of input prompts
completions: List of completion strings
+ answer: Expected answer or list of answers
**kwargs: Additional arguments
-
Returns:
list[float]: Reward values for each completion
"""
extracted_responses = [r1_extract_xml_answer(r) for r in completions]
return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
-def r1_strict_format_reward_func(completions, **kwargs) -> list[float]:
+def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
+ """Rewards completions with flexible XML format."""
+    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
+ matches = [re.match(pattern, r) for r in completions]
+ return [0.5 if match else 0.0 for match in matches]
+
+def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
"""Rewards completions with strict XML format.
-
Args:
+ prompts: List of input prompts
completions: List of completion strings
+ answer: Expected answer or list of answers
**kwargs: Additional arguments
-
Returns:
list[float]: Reward values for each completion
"""
@@ -88,98 +155,128 @@ def r1_strict_format_reward_func(completions, **kwargs) -> list[float]:
matches = [re.match(pattern, r) for r in completions]
return [0.5 if match else 0.0 for match in matches]
-def r1_soft_format_reward_func(completions, **kwargs) -> list[float]:
- """Rewards completions with flexible XML format.
-
- Args:
- completions: List of completion strings
- **kwargs: Additional arguments
-
- Returns:
- list[float]: Reward values for each completion
- """
-    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
- matches = [re.match(pattern, r) for r in completions]
- return [0.5 if match else 0.0 for match in matches]
-
-def r1_count_xml(text: str) -> float:
+def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     """Calculates score based on XML formatting.
-
     Args:
-        text: Input text string
-
+        prompts: List of input prompts (unused)
+        completions: List of completion strings to evaluate
+        answer: Expected answer or list of answers (unused)
+        **kwargs: Additional arguments
     Returns:
-        float: Score based on XML tag presence and formatting
+        list[float]: List of scores based on XML tag presence and formatting
     """
-    count = 0.0
-    if text.count("<think>\n") == 1:
-        count += 0.125
-    if text.count("\n</think>\n") == 1:
-        count += 0.125
-    if text.count("\n<answer>\n") == 1:
-        count += 0.125
-    count -= len(text.split("\n</answer>\n")[-1])*0.001
-    if text.count("\n</answer>") == 1:
-        count += 0.125
-    count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
-    return count
+    scores = []
+    for text in completions:
+        count = 0.0
+        if text.count("<think>\n") == 1:
+            count += 0.125
+        if text.count("\n</think>\n") == 1:
+            count += 0.125
+        if text.count("\n<answer>\n") == 1:
+            count += 0.125
+        count -= len(text.split("\n</answer>\n")[-1])*0.001
+        if text.count("\n</answer>") == 1:
+            count += 0.125
+        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
+        scores.append(count)
+    return scores
def grpo_loss(
- model,
- tokenizer,
- prompts,
- reward_funcs=None,
- beta=0.1,
- group_size=4,
- epsilon=1e-4,
- ref_model=None
- ):
- """
- Calculates the GRPO loss with support for multiple reward functions.
+ model,
+ tokenizer,
+ batch,
+ reward_funcs=None,
+ beta=0.1,
+ group_size=4,
+ epsilon=1e-4,
+ ref_model=None,
+ max_tokens=128,
+ temperature=1.0
+):
+ """Modified GRPO loss function with better error handling"""
+ prompt_tokens, answer_tokens, prompt_text, answer_text = batch
+ batch_size = len(prompt_tokens)
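+    # GRPO draws group_size completions per prompt; their rewards are later
+    # normalized within each group to form the relative advantages.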
- Args:
- model: The model to optimize
- tokenizer: Tokenizer for processing inputs
- prompts: List of input prompts
- reward_funcs: List of reward functions to use
- beta: KL penalty coefficient
- group_size: Number of completions per prompt
- epsilon: Small constant for numerical stability
- ref_model: Optional reference model for KL divergence
-
- Returns:
- tuple: (loss, total_sequence_length, metrics_dict)
- """
- batch_size = len(prompts)
- # Generate multiple completions for each prompt
+ # Generate completions for each prompt
all_completions = []
+ all_completion_texts = []
- for prompt in prompts:
+ for prompt in prompt_tokens:
+ prompt_tensor = mx.array(prompt)
prompt_completions = []
+ prompt_completion_texts = []
+
+ # Generate group_size completions for each prompt
for _ in range(group_size):
- completion = generate(model, tokenizer, prompt)
- prompt_completions.append(completion)
+ try:
+ completion_ids = generate_for_grpo(
+ model,
+ prompt_tensor,
+ max_tokens,
+ tokenizer=tokenizer,
+ temperature=temperature
+ )
+
+                # Fall back to the raw prompt if generation returned nothing, so each
+                # prompt still contributes exactly group_size completions
+                if completion_ids is None:
+                    print("Warning: generate_for_grpo returned None")
+                    completion_ids = prompt_tensor
+
+ completion_text = tokenizer.decode(completion_ids.tolist())
+
+ prompt_completions.append(completion_ids)
+ prompt_completion_texts.append(completion_text)
+
+ except Exception as e:
+ print(f"Error in completion generation: {str(e)}")
+ # Fallback to using original prompt
+ prompt_completions.append(prompt_tensor)
+ prompt_completion_texts.append(tokenizer.decode(prompt_tensor.tolist()))
+
all_completions.extend(prompt_completions)
+ all_completion_texts.extend(prompt_completion_texts)
+
+ # Verify we have the expected number of completions
+ assert len(all_completions) == batch_size * group_size
+ assert len(all_completion_texts) == batch_size * group_size
- # Tokenize all prompts + completions
- tokenized_inputs = tokenizer(
- [p + c for p, c in zip(prompts * group_size, all_completions)],
- return_tensors="np",
- padding=True
- )
+ # Expand answer_text and prompt_text to match completion groups
+ expanded_answers = []
+ expanded_prompts = []
+ for i in range(batch_size):
+ expanded_answers.extend([answer_text[i]] * group_size)
+ expanded_prompts.extend([prompt_text[i]] * group_size)
+
- inputs = mx.array(tokenized_inputs["input_ids"])
- attention_mask = mx.array(tokenized_inputs["attention_mask"])
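+    # Right-pad each sampled completion to the longest sequence in the batch and
+    # build an attention mask that marks the real (non-padded) tokens.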
+ max_length = max(ids.shape[0] for ids in all_completions)
+ padded_completions = []
+ attention_masks = []
- # Get lengths for proper masking
+ for completion_ids in all_completions:
+ padding_length = max_length - completion_ids.shape[0]
+ if padding_length > 0:
+ padding = mx.zeros((padding_length,), dtype=completion_ids.dtype)
+ padded_ids = mx.concatenate([completion_ids, padding])
+ mask = mx.concatenate([mx.ones_like(completion_ids), mx.zeros_like(padding)])
+ else:
+ padded_ids = completion_ids
+ mask = mx.ones_like(completion_ids)
+ padded_completions.append(padded_ids)
+ attention_masks.append(mask)
+
+ inputs = mx.stack(padded_completions)
+ attention_mask = mx.stack(attention_masks)
lengths = attention_mask.sum(axis=1)
# Get logits from current model
logits = model(inputs).astype(mx.float32)
# Calculate log probabilities
- log_probs = mx.log_softmax(logits[:, :-1, :], axis=-1)
+ log_probs = nn.log_softmax(logits[:, :-1, :], axis=-1)
# Prepare targets
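+    # (inputs are shifted by one: the logits at position t predict the token at t + 1)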
targets = inputs[:, 1:]
@@ -197,7 +294,7 @@ def grpo_loss(
else:
ref_logits = model(inputs).astype(mx.float32)
- ref_log_probs = mx.log_softmax(ref_logits[:, :-1, :], axis=-1)
+ ref_log_probs = nn.log_softmax(ref_logits[:, :-1, :], axis=-1)
ref_token_log_probs = mx.take_along_axis(
ref_log_probs,
targets.reshape(*targets.shape, 1),
@@ -210,7 +307,11 @@ def grpo_loss(
# Calculate combined rewards from all reward functions
rewards = mx.zeros((len(all_completions),))
for reward_func in reward_funcs:
- func_rewards = mx.array(reward_func(all_completions))
+        func_rewards = mx.array(reward_func(
+            prompts=expanded_prompts,
+            completions=all_completion_texts,
+            answer=expanded_answers
+        ))
rewards += func_rewards
# Normalize rewards if using multiple reward functions
@@ -245,8 +346,12 @@ def grpo_loss(
# Collect metrics for each reward function separately
reward_metrics = {}
for i, reward_func in enumerate(reward_funcs):
- func_rewards = mx.array(reward_func(all_completions))
- func_grouped_rewards = func_rewards.reshape(batch_size, group_size)
+        func_rewards = mx.array(reward_func(
+            prompts=expanded_prompts,
+            completions=all_completion_texts,
+            answer=expanded_answers
+        ))
reward_metrics[f'reward_func_{i}_mean'] = mx.mean(func_rewards)
reward_metrics[f'reward_func_{i}_std'] = mx.std(func_rewards)
@@ -264,26 +369,30 @@ def grpo_loss(
def iterate_grpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
"""
- Creates batches from prompt-answer pairs for GRPO training.
+ Creates batches from dataset entries for GRPO training.
Args:
- dataset: List of (prompt, answer) pairs
+ dataset: List of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples
tokenizer: Tokenizer for processing inputs
batch_size: Size of each batch
max_seq_length: Maximum sequence length
train: Whether this is for training
Yields:
- List of prompts for the current batch
+ Tuple containing:
+ - prompts_tokens: List of token sequences for current batch
+ - answers_tokens: List of token sequences
+ - prompts_text: List of prompt strings
+ - answers_text: List of answer strings
"""
- # Verify dataset is not empty and has correct format
- if not dataset or not isinstance(dataset[0], (tuple, list)) or len(dataset[0]) != 2:
- raise ValueError("Dataset must be a list of (prompt, answer) pairs")
-
- # Sort by combined length of prompt + answer
+ # Verify dataset format
+ if not dataset or not isinstance(dataset[0], tuple) or len(dataset[0]) != 4:
+ raise ValueError("Dataset must be list of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples")
+
+ # Sort by combined length of prompt + answer tokens
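+    # (length-sorted batches group similarly sized sequences, reducing padding per batch)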
idx = sorted(range(len(dataset)),
key=lambda i: len(dataset[i][0]) + len(dataset[i][1]))
-
+
if len(dataset) < batch_size:
raise ValueError(
f"Dataset must have at least batch_size={batch_size} "
@@ -306,22 +415,22 @@ def iterate_grpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=F
indices = np.random.permutation(len(batch_idx)) if train else range(len(batch_idx))
for i in indices:
- # Get current batch of prompt-answer pairs
+ # Get current batch
current_batch = [dataset[j] for j in batch_idx[i]]
- # Extract prompts and answers
- prompts = [pair[0] for pair in current_batch]
- answers = [pair[1] for pair in current_batch]
-
- if any(len(p) > max_seq_length for p in prompts):
+ # Extract all components
+ prompts_tokens = [item[0] for item in current_batch]
+ answers_tokens = [item[1] for item in current_batch]
+ prompts_text = [item[2] for item in current_batch]
+ answers_text = [item[3] for item in current_batch]
+
+ if any(len(p) > max_seq_length for p in prompts_tokens):
print(
f"[WARNING] Some prompts are longer than {max_seq_length} tokens. "
"Long prompts will be truncated."
)
-
- # For GRPO, we only need to yield the prompts
- # The answers will be used by the reward functions
- yield prompts
+
+ yield prompts_tokens, answers_tokens, prompts_text, answers_text
if not train:
break
@@ -342,11 +451,19 @@ def evaluate_grpo(
loss: callable = grpo_loss,
iterate_batches: callable = iterate_grpo_batches
):
+ """
+ Evaluate model using GRPO loss.
+ Returns:
+ tuple: (average loss, number of tokens, average metrics)
+ """
all_losses = 0
ntokens = 0
-
+    all_metrics = None  # Populated from the first batch's metrics
+
+ # Create iterator for batches
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
-
+
+ # Iterate through batches
for _, batch in zip(
index_iterator,
iterate_batches(
@@ -356,35 +473,41 @@ def evaluate_grpo(
max_seq_length=max_seq_length,
),
):
- prompts = batch
+ # Calculate loss for current batch
losses, toks, metrics = loss(
model=model,
tokenizer=tokenizer,
- prompts=prompts,
+ batch=batch,
reward_funcs=reward_funcs,
beta=beta,
group_size=group_size,
epsilon=epsilon,
ref_model=ref_model
)
+
+ # Accumulate losses and tokens
all_losses += losses * toks
ntokens += toks
-
+
+ # Accumulate metrics
if all_metrics is None:
all_metrics = {k: v * toks for k, v in metrics.items()}
else:
for k, v in metrics.items():
all_metrics[k] += v * toks
-
+
+ # Evaluate accumulated values
mx.eval(all_losses, ntokens)
-
+
+ # Aggregate across distributed workers
all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
all_metrics = {k: mx.distributed.all_sum(v) for k, v in all_metrics.items()}
-
+
+ # Calculate averages
avg_metrics = {k: (v / ntokens).item() for k, v in all_metrics.items()}
avg_loss = (all_losses / ntokens).item()
-
+
return avg_loss, ntokens, avg_metrics
@@ -420,8 +543,18 @@ def train_grpo(
state = [model.state, optimizer.state]
def step(batch):
# Forward and backward pass
- (loss, toks, metrics), grad = loss_value_and_grad(model, *batch)
+ (loss, toks, metrics), grad = loss_value_and_grad(
+ model,
+ tokenizer=tokenizer,
+ batch=batch,
+ reward_funcs=reward_funcs,
+ beta=args.beta,
+ group_size=args.group_size,
+ epsilon=args.epsilon,
+ ref_model=ref_model
+ )
# All reduce the gradients if running in distributed mode
grad = average_gradients(grad)
@@ -430,7 +563,7 @@ def train_grpo(
optimizer.update(model, grad)
return loss, toks, metrics
-
+
loss_value_and_grad = nn.value_and_grad(model, loss)
losses = 0