diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py
index 5f00d3e3..0a3e36c9 100644
--- a/llms/mlx_lm/tuner/datasets.py
+++ b/llms/mlx_lm/tuner/datasets.py
@@ -4,6 +4,7 @@ import types
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
+from .utils import GRPOExample
from transformers import PreTrainedTokenizer
@@ -11,7 +12,7 @@ class GRPODataset:
"""
Dataset wrapper for GRPO training data.
Each example should have a 'prompt' and 'answer' field.
- Returns data in (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple format.
+ Returns data as GRPOExample instances.
"""
def __init__(
self,
@@ -22,33 +23,40 @@ class GRPODataset:
use_chat_template: bool = False,
use_prompt: bool = False
):
- self._data = []
+ self._data: List[GRPOExample] = []
for item in data:
prompt_str = str(item[prompt_key])
answer_str = str(item[answer_key])
+
if use_chat_template:
prompt_tokens = tokenizer.apply_chat_template(
[
{'role': 'system', 'content': """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
-                         The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
-                         The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>."""},
-                        {'role': 'user', 'content': prompt_str}
+                        The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
+                        The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>."""},
+                        {'role': 'user', 'content': prompt_str}
],
)
answer_tokens = tokenizer.encode(answer_str)
else:
if use_prompt:
prompt_tokens = tokenizer.encode(f"""A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
-                The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
-                The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>.
-                User: {prompt_str} Assistant: """)
+                The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
+                The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>.
+                User: {prompt_str} Assistant: """)
else:
prompt_tokens = tokenizer.encode(prompt_str)
answer_tokens = tokenizer.encode(answer_str)
- self._data.append((prompt_tokens, answer_tokens, prompt_str, answer_str))
+
+ self._data.append(GRPOExample(
+ prompt_tokens=prompt_tokens,
+ answer_tokens=answer_tokens,
+ prompt_text=prompt_str,
+ answer_text=answer_str
+ ))
- def __getitem__(self, idx: int) -> Tuple[List[int], List[int], str, str]:
- """Returns a (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple."""
+ def __getitem__(self, idx: int) -> GRPOExample:
+ """Returns a GRPOExample instance."""
return self._data[idx]
def __len__(self) -> int:
diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index 4a1e6bbf..13954665 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -2,6 +2,7 @@
import time
from dataclasses import dataclass, field
+from typing import List, Iterator, Optional
from pathlib import Path
import re
@@ -10,6 +11,7 @@ import mlx.nn as nn
import numpy as np
from mlx.utils import tree_flatten
+from .utils import GRPOBatch, GRPOExample
from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients, iterate_batches
@dataclass
@@ -109,7 +111,7 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> li
return scores
-def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
+def generate_grpo(model: nn.Module, prompt, max_tokens, tokenizer, temperature):
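+    """Autoregressively sample up to max_tokens tokens from the prompt, stopping early when the end-of-answer sequence is produced."""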
if len(prompt.shape) == 1:
prompt = prompt[None, :]
if prompt.shape[1] == 0:
@@ -117,9 +119,11 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
    end_sequence = tokenizer.encode("</answer>")
end_sequence_length = len(end_sequence)
- output = mx.zeros((prompt.shape[1] + max_tokens,), dtype=mx.int32)
- output[:prompt.shape[1]] = prompt[0]
- current_length = prompt.shape[1]
+
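+    # Pre-allocate a fixed-size buffer that holds the prompt followed by up to max_tokens generated tokens.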
+ initial_length = prompt.shape[1]
+ output = mx.zeros((initial_length + max_tokens,), dtype=mx.int32)
+ output[:initial_length] = prompt[0]
+ current_length = initial_length
try:
def sample(logits):
@@ -145,7 +149,7 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
if last_tokens == end_sequence:
break
- if current_length > prompt.shape[1]:
+ if current_length > initial_length:
return output[:current_length]
except Exception as e:
@@ -155,7 +159,7 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
return None
-def get_per_token_logps(model, inputs, lengths):
+def get_per_token_logps(model: nn.Module, inputs, lengths):
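+    """Return per-token log-probabilities of the target tokens for each sequence, up to its given length."""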
logits = model(inputs).astype(mx.float16)
logits = logits[:, :-1, :]
targets = inputs[:, 1:]
@@ -176,10 +180,10 @@ def get_per_token_logps(model, inputs, lengths):
def grpo_loss(
- model,
- ref_model,
+ model: nn.Module,
+ ref_model: Optional[nn.Module],
tokenizer,
- batch,
+    batch: GRPOBatch,
reward_funcs=None,
beta=0.1,
group_size=4,
@@ -187,15 +191,18 @@ def grpo_loss(
max_tokens=64,
temperature=1.0
):
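+    """Compute the GRPO objective for one batch: sample group_size completions per prompt, score them with reward_funcs, and regularize the policy against ref_model with KL weight beta."""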
- prompt_tokens, answer_tokens, prompt_text, answer_text = batch
- batch_size = len(prompt_tokens)
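+    # GRPOBatch carries parallel lists: tokenized prompts/answers plus their raw text for the reward functions.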
+ prompts_tokens = batch.prompt_tokens
+ answers_tokens = batch.answer_tokens
+ prompts_text = batch.prompt_texts
+ answers_text = batch.answer_texts
+ batch_size = len(prompts_tokens)
# Generation logic remains the same
all_completions = []
all_completion_texts = []
for i in range(0, batch_size, batch_size):
- batch_prompts = prompt_tokens[i:i+batch_size]
+ batch_prompts = prompts_tokens[i:i+batch_size]
for prompt in batch_prompts:
prompt_tensor = mx.array(prompt)
for _ in range(group_size):
@@ -219,8 +226,8 @@ def grpo_loss(
expanded_answers = []
expanded_prompts = []
for i in range(batch_size):
- expanded_answers.extend([answer_text[i]] * group_size)
- expanded_prompts.extend([prompt_text[i]] * group_size)
+ expanded_answers.extend([answers_text[i]] * group_size)
+ expanded_prompts.extend([prompts_text[i]] * group_size)
max_length = max(ids.shape[0] for ids in all_completions)
padded_completions = []
@@ -332,60 +339,66 @@ def grpo_loss(
return loss, sequence_lengths.sum(), metrics
-def iterate_grpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
- if not dataset or not isinstance(dataset[0], tuple) or len(dataset[0]) != 4:
- raise ValueError("Dataset must be list of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples")
-
- # Sort by length but use generator to avoid keeping full sorted list in memory
- def length_key(i):
- return len(dataset[i][0]) + len(dataset[i][1])
-
- idx = sorted(range(len(dataset)), key=length_key)
-
+def iterate_grpo_batches(
+ dataset: List[GRPOExample],
+ batch_size: int,
+ max_seq_length: int,
+ train: bool = False,
+) -> Iterator[GRPOBatch]:
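+    """Yield GRPOBatch objects over length-sorted examples; loops indefinitely when train=True, single pass otherwise."""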
if len(dataset) < batch_size:
raise ValueError(
f"Dataset must have at least batch_size={batch_size} "
f"examples but only has {len(dataset)}."
)
+ # Get MLX distributed setup
step = mx.distributed.init().size()
if batch_size % step != 0:
raise ValueError("The batch size must be divisible by the number of workers")
- # Use generator for batch indices
- def batch_index_generator():
- for i in range(0, len(idx) - batch_size + 1, batch_size):
- yield idx[i : i + batch_size : step]
-
+ # Sort by combined length for efficient batching
+ def length_key(example: GRPOExample) -> int:
+ return len(example.prompt_tokens) + len(example.answer_tokens)
+
+ sorted_dataset = sorted(dataset, key=length_key)
+
+ # Create batch indices
+    num_complete_batches = len(dataset) // batch_size
+ batch_starts = range(0, num_complete_batches * batch_size, batch_size)
+
while True:
- indices = (
- np.random.permutation(list(batch_index_generator())) if train
- else batch_index_generator()
- )
+        # Shuffle batch order when training; keep the sorted order for evaluation
+        starts = np.random.permutation(batch_starts) if train else batch_starts
- for batch_idx in indices:
- current_batch = [dataset[j] for j in batch_idx]
+        for start_idx in starts:
+ # Account for distributed setup by taking every step-th example
+ batch_idx = list(range(start_idx, start_idx + batch_size, step))
+ current_batch = [sorted_dataset[j] for j in batch_idx]
- prompts_tokens = [item[0] for item in current_batch]
- answers_tokens = [item[1] for item in current_batch]
- prompts_text = [item[2] for item in current_batch]
- answers_text = [item[3] for item in current_batch]
-
- if any(len(p) > max_seq_length for p in prompts_tokens):
+ # Create batch using dataclass attributes
+ batch = GRPOBatch(
+ prompt_tokens=[ex.prompt_tokens for ex in current_batch],
+ answer_tokens=[ex.answer_tokens for ex in current_batch],
+ prompt_texts=[ex.prompt_text for ex in current_batch],
+ answer_texts=[ex.answer_text for ex in current_batch]
+ )
+
+ # Check sequence lengths
+ if any(len(tokens) > max_seq_length for tokens in batch.prompt_tokens):
print(
f"[WARNING] Some prompts are longer than {max_seq_length} tokens. "
"Long prompts will be truncated."
)
-
- yield prompts_tokens, answers_tokens, prompts_text, answers_text
-
+
+ yield batch
+
if not train:
break
def evaluate_grpo(
- model,
- ref_model,
+ model: nn.Module,
+ ref_model: Optional[nn.Module],
dataset,
tokenizer,
batch_size,
@@ -415,7 +428,6 @@ def evaluate_grpo(
index_iterator,
iterate_batches(
dataset=dataset,
- tokenizer=tokenizer,
batch_size=batch_size,
max_seq_length=max_seq_length,
),
@@ -459,12 +471,12 @@ def evaluate_grpo(
def train_grpo(
- model,
- ref_model,
+ model: nn.Module,
+ ref_model: Optional[nn.Module],
tokenizer,
optimizer,
- train_dataset,
- val_dataset,
+ train_dataset: List[GRPOExample],
+ val_dataset: List[GRPOExample],
reward_funcs = [
r1_accuracy_reward_func,
r1_int_reward_func,
@@ -535,7 +547,6 @@ def train_grpo(
range(1, args.iters + 1),
iterate_batches(
dataset=train_dataset,
- tokenizer=tokenizer,
batch_size=args.batch_size,
max_seq_length=args.max_seq_length,
train=True,
diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py
index d86e01dd..d3497177 100644
--- a/llms/mlx_lm/tuner/utils.py
+++ b/llms/mlx_lm/tuner/utils.py
@@ -2,7 +2,8 @@
import json
import types
from pathlib import Path
-from typing import Dict
+from typing import Dict, List
+from dataclasses import dataclass
import mlx.core as mx
import mlx.nn as nn
@@ -275,3 +276,19 @@ def print_trainable_parameters(model):
f"Trainable parameters: {(trainable_p * 100 / total_p):.3f}% "
f"({trainable_p:.3f}M/{total_p:.3f}M)"
)
+
+@dataclass
+class GRPOExample:
+ """Single example for GRPO training/inference."""
+ prompt_tokens: List[int]
+ answer_tokens: List[int]
+ prompt_text: str
+ answer_text: str
+
+@dataclass
+class GRPOBatch:
+ """A batch of GRPO examples."""
+ prompt_tokens: List[List[int]]
+ answer_tokens: List[List[int]]
+ prompt_texts: List[str]
+ answer_texts: List[str]
\ No newline at end of file