diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py
index 5f00d3e3..0a3e36c9 100644
--- a/llms/mlx_lm/tuner/datasets.py
+++ b/llms/mlx_lm/tuner/datasets.py
@@ -4,6 +4,7 @@
 import types
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple
+from .utils import GRPOExample
 
 from transformers import PreTrainedTokenizer
 
@@ -11,7 +12,7 @@ class GRPODataset:
     """
     Dataset wrapper for GRPO training data.
     Each example should have a 'prompt' and 'answer' field.
-    Returns data in (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple format.
+    Returns data as GRPOExample instances.
     """
    def __init__(
         self,
@@ -22,33 +23,40 @@ class GRPODataset:
         use_chat_template: bool = False,
         use_prompt: bool = False
     ):
-        self._data = []
+        self._data: List[GRPOExample] = []
         for item in data:
             prompt_str = str(item[prompt_key])
             answer_str = str(item[answer_key])
+
             if use_chat_template:
                 prompt_tokens = tokenizer.apply_chat_template(
                     [
                         {'role': 'system', 'content': """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
-                    The assistantfirst thinks about the reasoning process in the mind and then provides the user with the answer.
-                    The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>."""},
-                    {'role': 'user', 'content': prompt_str}
+                        The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
+                        The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>."""},
+                        {'role': 'user', 'content': prompt_str}
                     ],
                 )
                 answer_tokens = tokenizer.encode(answer_str)
             else:
                 if use_prompt:
                     prompt_tokens = tokenizer.encode(f"""A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
-                The assistantfirst thinks about the reasoning process in the mind and then provides the user with the answer.
-                The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>.
-                User: {prompt_str} Assistant: """)
+                    The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
+                    The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>.
+                    User: {prompt_str} Assistant: """)
                 else:
                     prompt_tokens = tokenizer.encode(prompt_str)
                 answer_tokens = tokenizer.encode(answer_str)
-            self._data.append((prompt_tokens, answer_tokens, prompt_str, answer_str))
+
+            self._data.append(GRPOExample(
+                prompt_tokens=prompt_tokens,
+                answer_tokens=answer_tokens,
+                prompt_text=prompt_str,
+                answer_text=answer_str
+            ))
 
-    def __getitem__(self, idx: int) -> Tuple[List[int], List[int], str, str]:
-        """Returns a (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple."""
+    def __getitem__(self, idx: int) -> GRPOExample:
+        """Returns a GRPOExample instance."""
         return self._data[idx]
 
     def __len__(self) -> int:
diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index 4a1e6bbf..13954665 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -2,6 +2,7 @@
 
 import time
 from dataclasses import dataclass, field
+from typing import List, Iterator, Optional
 from pathlib import Path
 import re
 
@@ -10,6 +11,7 @@ import mlx.nn as nn
 import numpy as np
 from mlx.utils import tree_flatten
 
+from .utils import GRPOBatch, GRPOExample
 from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients, iterate_batches
 
 @dataclass
@@ -109,7 +111,7 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list:
     return scores
 
 
-def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
+def generate_grpo(model: nn.Module, prompt, max_tokens, tokenizer, temperature):
     if len(prompt.shape) == 1:
         prompt = prompt[None, :]
     if prompt.shape[1] == 0:
@@ -117,9 +119,11 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
 
     end_sequence = tokenizer.encode("</answer>")
     end_sequence_length = len(end_sequence)
-    output = mx.zeros((prompt.shape[1] + max_tokens,), dtype=mx.int32)
-    output[:prompt.shape[1]] = prompt[0]
-    current_length = prompt.shape[1]
+
+    initial_length = prompt.shape[1]
+    output = mx.zeros((initial_length + max_tokens,), dtype=mx.int32)
+    output[:initial_length] = prompt[0]
+    current_length = initial_length
 
     try:
         def sample(logits):
@@ -145,7 +149,7 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
             if last_tokens == end_sequence:
                 break
 
-        if current_length > prompt.shape[1]:
+        if current_length > initial_length:
             return output[:current_length]
 
     except Exception as e:
@@ -155,7 +159,7 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
     return None
 
 
-def get_per_token_logps(model, inputs, lengths):
+def get_per_token_logps(model: nn.Module, inputs, lengths):
     logits = model(inputs).astype(mx.float16)
     logits = logits[:, :-1, :]
     targets = inputs[:, 1:]
@@ -176,10 +180,10 @@ def get_per_token_logps(model, inputs, lengths):
 
 
 def grpo_loss(
-    model,
-    ref_model,
+    model: nn.Module,
+    ref_model: Optional[nn.Module],
     tokenizer,
-    batch,
+    batch: GRPOBatch,
     reward_funcs=None,
     beta=0.1,
     group_size=4,
@@ -187,15 +191,18 @@ def grpo_loss(
     max_tokens=64,
     temperature=1.0
 ):
-    prompt_tokens, answer_tokens, prompt_text, answer_text = batch
-    batch_size = len(prompt_tokens)
+    prompts_tokens = batch.prompt_tokens
+    answers_tokens = batch.answer_tokens
+    prompts_text = batch.prompt_texts
+    answers_text = batch.answer_texts
+    batch_size = len(prompts_tokens)
 
     # Generation logic remains the same
     all_completions = []
     all_completion_texts = []
 
     for i in range(0, batch_size, batch_size):
-        batch_prompts = prompt_tokens[i:i+batch_size]
+        batch_prompts = prompts_tokens[i:i+batch_size]
         for prompt in batch_prompts:
             prompt_tensor = mx.array(prompt)
             for _ in range(group_size):
@@ -219,8 +226,8 @@ def grpo_loss(
     expanded_answers = []
     expanded_prompts = []
     for i in range(batch_size):
-        expanded_answers.extend([answer_text[i]] * group_size)
-        expanded_prompts.extend([prompt_text[i]] * group_size)
+        expanded_answers.extend([answers_text[i]] * group_size)
+        expanded_prompts.extend([prompts_text[i]] * group_size)
 
     max_length = max(ids.shape[0] for ids in all_completions)
     padded_completions = []
@@ -332,60 +339,66 @@ def grpo_loss(
     return loss, sequence_lengths.sum(), metrics
 
 
-def iterate_grpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
-    if not dataset or not isinstance(dataset[0], tuple) or len(dataset[0]) != 4:
-        raise ValueError("Dataset must be list of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples")
-
-    # Sort by length but use generator to avoid keeping full sorted list in memory
-    def length_key(i):
-        return len(dataset[i][0]) + len(dataset[i][1])
-
-    idx = sorted(range(len(dataset)), key=length_key)
-
+def iterate_grpo_batches(
+    dataset: List[GRPOExample],
+    batch_size: int,
+    max_seq_length: int,
+    train: bool = False,
+) -> Iterator[GRPOBatch]:
     if len(dataset) < batch_size:
         raise ValueError(
             f"Dataset must have at least batch_size={batch_size} "
             f"examples but only has {len(dataset)}."
         )
 
+    # Get MLX distributed setup
     step = mx.distributed.init().size()
     if batch_size % step != 0:
         raise ValueError("The batch size must be divisible by the number of workers")
 
-    # Use generator for batch indices
-    def batch_index_generator():
-        for i in range(0, len(idx) - batch_size + 1, batch_size):
-            yield idx[i : i + batch_size : step]
-
+    # Sort by combined length for efficient batching
+    def length_key(example: GRPOExample) -> int:
+        return len(example.prompt_tokens) + len(example.answer_tokens)
+
+    sorted_dataset = sorted(dataset, key=length_key)
+
+    # Create batch indices
+    num_complete_batches = (len(dataset) - batch_size + 1) // batch_size
+    batch_starts = range(0, num_complete_batches * batch_size, batch_size)
+
     while True:
-        indices = (
-            np.random.permutation(list(batch_index_generator())) if train
-            else batch_index_generator()
-        )
+        # Shuffle batch start indices
+        shuffled_starts = np.random.permutation(batch_starts)
 
-        for batch_idx in indices:
-            current_batch = [dataset[j] for j in batch_idx]
+        for start_idx in shuffled_starts:
+            # Account for distributed setup by taking every step-th example
+            batch_idx = list(range(start_idx, start_idx + batch_size, step))
+            current_batch = [sorted_dataset[j] for j in batch_idx]
 
-            prompts_tokens = [item[0] for item in current_batch]
-            answers_tokens = [item[1] for item in current_batch]
-            prompts_text = [item[2] for item in current_batch]
-            answers_text = [item[3] for item in current_batch]
-
-            if any(len(p) > max_seq_length for p in prompts_tokens):
+            # Create batch using dataclass attributes
+            batch = GRPOBatch(
+                prompt_tokens=[ex.prompt_tokens for ex in current_batch],
+                answer_tokens=[ex.answer_tokens for ex in current_batch],
+                prompt_texts=[ex.prompt_text for ex in current_batch],
+                answer_texts=[ex.answer_text for ex in current_batch]
+            )
+
+            # Check sequence lengths
+            if any(len(tokens) > max_seq_length for tokens in batch.prompt_tokens):
                 print(
                     f"[WARNING] Some prompts are longer than {max_seq_length} tokens. "
                     "Long prompts will be truncated."
                 )
-
-            yield prompts_tokens, answers_tokens, prompts_text, answers_text
-
+
+            yield batch
+
         if not train:
             break
 
 
 def evaluate_grpo(
-    model,
-    ref_model,
+    model: nn.Module,
+    ref_model: Optional[nn.Module],
     dataset,
     tokenizer,
     batch_size,
@@ -415,7 +428,6 @@ def evaluate_grpo(
         index_iterator,
         iterate_batches(
             dataset=dataset,
-            tokenizer=tokenizer,
             batch_size=batch_size,
             max_seq_length=max_seq_length,
         ),
@@ -459,12 +471,12 @@ def evaluate_grpo(
 
 
 def train_grpo(
-    model,
-    ref_model,
+    model: nn.Module,
+    ref_model: Optional[nn.Module],
     tokenizer,
     optimizer,
-    train_dataset,
-    val_dataset,
+    train_dataset: List[GRPOExample],
+    val_dataset: List[GRPOExample],
     reward_funcs = [
         r1_accuracy_reward_func,
         r1_int_reward_func,
@@ -535,7 +547,6 @@ def train_grpo(
         range(1, args.iters + 1),
         iterate_batches(
             dataset=train_dataset,
-            tokenizer=tokenizer,
             batch_size=args.batch_size,
             max_seq_length=args.max_seq_length,
             train=True,
diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py
index d86e01dd..d3497177 100644
--- a/llms/mlx_lm/tuner/utils.py
+++ b/llms/mlx_lm/tuner/utils.py
@@ -2,7 +2,8 @@
 import json
 import types
 from pathlib import Path
-from typing import Dict
+from typing import Dict, List
+from dataclasses import dataclass
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -275,3 +276,19 @@ def print_trainable_parameters(model):
         f"Trainable parameters: {(trainable_p * 100 / total_p):.3f}% "
         f"({trainable_p:.3f}M/{total_p:.3f}M)"
     )
+
+@dataclass
+class GRPOExample:
+    """Single example for GRPO training/inference."""
+    prompt_tokens: List[int]
+    answer_tokens: List[int]
+    prompt_text: str
+    answer_text: str
+
+@dataclass
+class GRPOBatch:
+    """A batch of GRPO examples."""
+    prompt_tokens: List[List[int]]
+    answer_tokens: List[List[int]]
+    prompt_texts: List[str]
+    answer_texts: List[str]
\ No newline at end of file
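Note (not part of the patch): a minimal, self-contained sketch of how the GRPOExample and GRPOBatch containers introduced above are meant to be consumed. The dataclass definitions mirror the ones added to tuner/utils.py; the to_batch helper and the character-level encode stand-in are illustrative assumptions rather than code from this diff. iterate_grpo_batches performs the same field-by-field regrouping over sorted, fixed-size slices of the dataset, and grpo_loss then reads the batch column-wise (batch.prompt_tokens, batch.answer_texts, and so on).

from dataclasses import dataclass
from typing import List


@dataclass
class GRPOExample:
    """Single example: token ids plus the raw text they came from."""
    prompt_tokens: List[int]
    answer_tokens: List[int]
    prompt_text: str
    answer_text: str


@dataclass
class GRPOBatch:
    """Column-wise view over a list of GRPOExample instances."""
    prompt_tokens: List[List[int]]
    answer_tokens: List[List[int]]
    prompt_texts: List[str]
    answer_texts: List[str]


def to_batch(examples: List[GRPOExample]) -> GRPOBatch:
    # Same field-by-field regrouping that iterate_grpo_batches performs per batch.
    return GRPOBatch(
        prompt_tokens=[ex.prompt_tokens for ex in examples],
        answer_tokens=[ex.answer_tokens for ex in examples],
        prompt_texts=[ex.prompt_text for ex in examples],
        answer_texts=[ex.answer_text for ex in examples],
    )


if __name__ == "__main__":
    # Toy "tokenizer": one integer per character, just to keep the sketch self-contained.
    encode = lambda s: [ord(c) for c in s]

    examples = [
        GRPOExample(encode(q), encode(a), q, a)
        for q, a in [("What is 2 + 2?", "4"), ("What is 3 * 3?", "9")]
    ]

    batch = to_batch(examples)
    # Downstream code accesses the batch column-wise, e.g. batch.prompt_tokens / batch.answer_texts.
    print(len(batch.prompt_tokens), batch.answer_texts)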