diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py
index d82fa0ff..fb19ba50 100644
--- a/llms/mlx_lm/tuner/datasets.py
+++ b/llms/mlx_lm/tuner/datasets.py
@@ -2,9 +2,8 @@ import itertools
import json
import types
from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Tuple
-from .utils import GRPOExample
from transformers import PreTrainedTokenizer
@@ -12,7 +11,7 @@ class GRPODataset:
"""
Dataset wrapper for GRPO training data.
Each example should have a 'prompt' and 'answer' field.
- Returns data as GRPOExample instances.
+ Returns data in (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple format.
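+    Prompts and answers are stored both as token ids and as raw strings so reward functions can operate on text.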
"""
def __init__(
self,
@@ -23,40 +22,33 @@ class GRPODataset:
use_chat_template: bool = False,
use_prompt: bool = False
):
- self._data: List[GRPOExample] = []
+ self._data = []
for item in data:
prompt_str = str(item[prompt_key])
answer_str = str(item[answer_key])
-
if use_chat_template:
prompt_tokens = tokenizer.apply_chat_template(
[
{'role': 'system', 'content': """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
-                    The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
-                    The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>."""},
-                    {'role': 'user', 'content': prompt_str}
+                    The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
+                    The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>."""},
+                    {'role': 'user', 'content': prompt_str}
],
)
answer_tokens = tokenizer.encode(answer_str)
else:
if use_prompt:
prompt_tokens = tokenizer.encode(f"""A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
-                    The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
-                    The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>.
-                    User: {prompt_str} Assistant: """)
+                    The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
+                    The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>.
+                    User: {prompt_str} Assistant: """)
else:
prompt_tokens = tokenizer.encode(prompt_str)
answer_tokens = tokenizer.encode(answer_str)
-
- self._data.append(GRPOExample(
- prompt_tokens=prompt_tokens,
- answer_tokens=answer_tokens,
- prompt_text=prompt_str,
- answer_text=answer_str
- ))
+ self._data.append((prompt_tokens, answer_tokens, prompt_str, answer_str))
- def __getitem__(self, idx: int) -> GRPOExample:
- """Returns a GRPOExample instance."""
+ def __getitem__(self, idx: int) -> Tuple[List[int], List[int], str, str]:
+ """Returns a (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple."""
return self._data[idx]
def __len__(self) -> int:
@@ -318,7 +310,7 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer):
train, valid, test = load_local_dataset(args, data_path, tokenizer, args)
else:
print(f"Loading Hugging Face dataset {args.data}.")
- train, valid, test = load_hf_dataset(args, args.data, tokenizer, args)
+ train, valid, test = load_hf_dataset(args.data, tokenizer, args)
if args.train and len(train) == 0:
raise ValueError(
diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index 13954665..d0fa5fae 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -1,18 +1,18 @@
# Copyright © 2024 Apple Inc.
-import time
+from typing import List, Optional, Callable
from dataclasses import dataclass, field
-from typing import List, Iterator, Optional
from pathlib import Path
+import time
import re
+from mlx.utils import tree_flatten
import mlx.core as mx
import mlx.nn as nn
import numpy as np
-from mlx.utils import tree_flatten
-from .utils import GRPOBatch, GRPOExample
-from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients, iterate_batches
+from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients
+
@dataclass
class GRPOTrainingArgs(TrainingArgs):
@@ -37,6 +37,9 @@ class GRPOTrainingArgs(TrainingArgs):
)
+RewardFunctions = Callable[[List[str], List[str], List[str]], List[float]]
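+# Each reward function maps (prompts, completions, answers) to one float score per completion.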
+
+
def r1_extract_xml_answer(text: str) -> str:
"""Extracts the answer from an XML formatted text string."""
try:
@@ -180,10 +183,10 @@ def get_per_token_logps(model: nn.Module, inputs, lengths):
def grpo_loss(
- model: nn.Module,
- ref_model: Optional[nn.Module],
+ model,
+ ref_model,
tokenizer,
- batch=GRPOBatch,
+ batch,
reward_funcs=None,
beta=0.1,
group_size=4,
@@ -191,18 +194,14 @@ def grpo_loss(
max_tokens=64,
temperature=1.0
):
- prompts_tokens = batch.prompt_tokens
- answers_tokens = batch.answer_tokens
- prompts_text = batch.prompt_texts
- answers_text = batch.answer_texts
- batch_size = len(prompts_tokens)
-
- # Generation logic remains the same
+ prompt_tokens, answer_tokens, prompt_text, answer_text = batch
+ batch_size = len(prompt_tokens)
+
all_completions = []
all_completion_texts = []
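+    # Sample group_size completions per prompt; range(0, batch_size, batch_size) covers the whole batch in a single pass.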
for i in range(0, batch_size, batch_size):
- batch_prompts = prompts_tokens[i:i+batch_size]
+ batch_prompts = prompt_tokens[i:i+batch_size]
for prompt in batch_prompts:
prompt_tensor = mx.array(prompt)
for _ in range(group_size):
@@ -212,8 +211,6 @@ def grpo_loss(
completion_text = tokenizer.decode(completion_ids.tolist())
all_completions.append(completion_ids)
all_completion_texts.append(completion_text)
-
- # Clear completion tensors
mx.eval(completion_ids)
del completion_ids
except Exception as e:
@@ -222,12 +219,11 @@ def grpo_loss(
mx.metal.clear_cache()
- # Prepare inputs
expanded_answers = []
expanded_prompts = []
for i in range(batch_size):
- expanded_answers.extend([answers_text[i]] * group_size)
- expanded_prompts.extend([prompts_text[i]] * group_size)
+ expanded_answers.extend([answer_text[i]] * group_size)
+ expanded_prompts.extend([prompt_text[i]] * group_size)
max_length = max(ids.shape[0] for ids in all_completions)
padded_completions = []
@@ -260,6 +256,8 @@ def grpo_loss(
ref_token_log_probs = token_log_probs
else:
ref_token_log_probs = get_per_token_logps(ref_model, inputs, lengths)
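+        # Force evaluation of the reference log-probs and clear the Metal cache so intermediate buffers are freed early.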
+ mx.eval(ref_token_log_probs)
+ mx.metal.clear_cache()
max_len = max(x.shape[0] for x in token_log_probs)
padded_log_probs = []
@@ -275,7 +273,7 @@ def grpo_loss(
token_log_probs = mx.stack(padded_log_probs)
ref_token_log_probs = mx.stack(padded_ref_log_probs)
- # Calculate rewards and advantages
+ # Rewards and advantages
rewards = mx.zeros((len(all_completions),))
for reward_func in reward_funcs:
func_rewards = mx.array(reward_func(
@@ -288,7 +286,7 @@ def grpo_loss(
if len(reward_funcs) > 1:
rewards /= len(reward_funcs)
- # Reshape rewards and compute advantages following GRPO formula
+ # Reshape rewards and compute advantages
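+    # Advantage per completion: (reward - mean of its prompt group) / (std of its prompt group + epsilon).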
rewards_reshaped = rewards.reshape(batch_size, group_size)
mean_rewards = mx.broadcast_to(mx.mean(rewards_reshaped, axis=1)[:, None], (rewards_reshaped.shape[0], group_size)).reshape(-1)
std_rewards = mx.broadcast_to(mx.std(rewards_reshaped, axis=1)[:, None], (rewards_reshaped.shape[0], group_size)).reshape(-1)
@@ -303,7 +301,7 @@ def grpo_loss(
# Compute policy ratio
policy_ratio = mx.exp(mx.array(token_log_probs - mx.stop_gradient(ref_token_log_probs)))
- # Compute per-token loss following GRPO formula
+ # Compute per-token loss
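+    # per-token loss = -(policy_ratio * advantage - beta * KL divergence), applied only to unpadded tokens via length_mask.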
per_token_loss = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask)
# Average over tokens
@@ -339,59 +337,51 @@ def grpo_loss(
return loss, sequence_lengths.sum(), metrics
-def iterate_grpo_batches(
- dataset: List[GRPOExample],
- batch_size: int,
- max_seq_length: int,
- train: bool = False,
-) -> Iterator[GRPOBatch]:
+def iterate_grpo_batches(dataset, batch_size, max_seq_length, train=False):
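+    """Yield batches of (prompts_tokens, answers_tokens, prompts_text, answers_text), length-sorted and strided across distributed workers."""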
+ if not dataset or not isinstance(dataset[0], tuple) or len(dataset[0]) != 4:
+        raise ValueError("Dataset must be a list of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples")
+
+ def length_key(i):
+ return len(dataset[i][0]) + len(dataset[i][1])
+
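+    # Order example indices by combined prompt+answer length so each batch contains similarly sized examples.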
+ idx = sorted(range(len(dataset)), key=length_key)
+
if len(dataset) < batch_size:
raise ValueError(
f"Dataset must have at least batch_size={batch_size} "
f"examples but only has {len(dataset)}."
)
- # Get MLX distributed setup
step = mx.distributed.init().size()
if batch_size % step != 0:
raise ValueError("The batch size must be divisible by the number of workers")
- # Sort by combined length for efficient batching
- def length_key(example: GRPOExample) -> int:
- return len(example.prompt_tokens) + len(example.answer_tokens)
-
- sorted_dataset = sorted(dataset, key=length_key)
-
- # Create batch indices
- num_complete_batches = (len(dataset) - batch_size + 1) // batch_size
- batch_starts = range(0, num_complete_batches * batch_size, batch_size)
-
+ def batch_index_generator():
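+        # Slice each batch with stride 'step' so every distributed worker receives a disjoint subset of examples.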
+ for i in range(0, len(idx) - batch_size + 1, batch_size):
+ yield idx[i : i + batch_size : step]
+
while True:
- # Shuffle batch start indices
- shuffled_starts = np.random.permutation(batch_starts)
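+        # Shuffle batch order each epoch during training; keep the deterministic order for evaluation.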
+ indices = (
+ np.random.permutation(list(batch_index_generator())) if train
+ else batch_index_generator()
+ )
- for start_idx in shuffled_starts:
- # Account for distributed setup by taking every step-th example
- batch_idx = list(range(start_idx, start_idx + batch_size, step))
- current_batch = [sorted_dataset[j] for j in batch_idx]
+ for batch_idx in indices:
+ current_batch = [dataset[j] for j in batch_idx]
- # Create batch using dataclass attributes
- batch = GRPOBatch(
- prompt_tokens=[ex.prompt_tokens for ex in current_batch],
- answer_tokens=[ex.answer_tokens for ex in current_batch],
- prompt_texts=[ex.prompt_text for ex in current_batch],
- answer_texts=[ex.answer_text for ex in current_batch]
- )
-
- # Check sequence lengths
- if any(len(tokens) > max_seq_length for tokens in batch.prompt_tokens):
+ prompts_tokens = [item[0] for item in current_batch]
+ answers_tokens = [item[1] for item in current_batch]
+ prompts_text = [item[2] for item in current_batch]
+ answers_text = [item[3] for item in current_batch]
+
+ if any(len(p) > max_seq_length for p in prompts_tokens):
print(
f"[WARNING] Some prompts are longer than {max_seq_length} tokens. "
"Long prompts will be truncated."
)
-
- yield batch
-
+
+ yield prompts_tokens, answers_tokens, prompts_text, answers_text
+
if not train:
break
@@ -407,7 +397,7 @@ def evaluate_grpo(
epsilon: float,
group_size: int,
max_seq_length,
- reward_funcs = None,
+ reward_funcs: Optional[List[RewardFunctions]] = None,
loss_fn: callable = grpo_loss,
iterate_batches: callable = iterate_grpo_batches
):
@@ -418,12 +408,10 @@ def evaluate_grpo(
"""
all_losses = 0
ntokens = 0
- all_metrics = None # Initialize metrics dictionary
+ all_metrics = None
- # Create iterator for batches
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
- # Iterate through batches
for _, batch in zip(
index_iterator,
iterate_batches(
@@ -432,7 +420,6 @@ def evaluate_grpo(
max_seq_length=max_seq_length,
),
):
- # Calculate loss for current batch
losses, toks, metrics = loss_fn(
model=model,
tokenizer=tokenizer,
@@ -444,18 +431,15 @@ def evaluate_grpo(
ref_model=ref_model
)
- # Accumulate losses and tokens
all_losses += losses * toks
ntokens += toks
- # Accumulate metrics
if all_metrics is None:
all_metrics = {k: v * toks for k, v in metrics.items()}
else:
for k, v in metrics.items():
all_metrics[k] += v * toks
- # Evaluate accumulated values
mx.eval(all_losses, ntokens)
# Aggregate across distributed workers
@@ -475,9 +459,9 @@ def train_grpo(
ref_model: Optional[nn.Module],
tokenizer,
optimizer,
- train_dataset: List[GRPOExample],
- val_dataset: List[GRPOExample],
- reward_funcs = [
+ train_dataset,
+ val_dataset,
+ reward_funcs: Optional[List[RewardFunctions]] = [
r1_accuracy_reward_func,
r1_int_reward_func,
r1_strict_format_reward_func,
diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py
index d3497177..7586fda4 100644
--- a/llms/mlx_lm/tuner/utils.py
+++ b/llms/mlx_lm/tuner/utils.py
@@ -2,8 +2,7 @@
import json
import types
from pathlib import Path
-from typing import Dict, List
-from dataclasses import dataclass
+from typing import Dict
import mlx.core as mx
import mlx.nn as nn
@@ -275,20 +274,4 @@ def print_trainable_parameters(model):
print(
f"Trainable parameters: {(trainable_p * 100 / total_p):.3f}% "
f"({trainable_p:.3f}M/{total_p:.3f}M)"
- )
-
-@dataclass
-class GRPOExample:
- """Single example for GRPO training/inference."""
- prompt_tokens: List[int]
- answer_tokens: List[int]
- prompt_text: str
- answer_text: str
-
-@dataclass
-class GRPOBatch:
- """A batch of GRPO examples."""
- prompt_tokens: List[List[int]]
- answer_tokens: List[List[int]]
- prompt_texts: List[str]
- answer_texts: List[str]
\ No newline at end of file
+ )
\ No newline at end of file