update new iterate batches function + nits

This commit is contained in:
Goekdeniz-Guelmez 2025-02-12 08:57:26 +01:00
parent e80bf95182
commit 5aeefc8c47
3 changed files with 100 additions and 64 deletions

View File

@@ -4,6 +4,7 @@ import types
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from .utils import GRPOExample
from transformers import PreTrainedTokenizer
@@ -11,7 +12,7 @@ class GRPODataset:
"""
Dataset wrapper for GRPO training data.
Each example should have a 'prompt' and 'answer' field.
Returns data in (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple format.
Returns data as GRPOExample instances.
"""
def __init__(
self,
@@ -22,33 +23,40 @@ class GRPODataset:
use_chat_template: bool = False,
use_prompt: bool = False
):
self._data = []
self._data: List[GRPOExample] = []
for item in data:
prompt_str = str(item[prompt_key])
answer_str = str(item[answer_key])
if use_chat_template:
prompt_tokens = tokenizer.apply_chat_template(
[
{'role': 'system', 'content': """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>."""},
{'role': 'user', 'content': prompt_str}
],
)
answer_tokens = tokenizer.encode(answer_str)
else:
if use_prompt:
prompt_tokens = tokenizer.encode(f"""A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>.
User: {prompt_str} Assistant: """)
else:
prompt_tokens = tokenizer.encode(prompt_str)
answer_tokens = tokenizer.encode(answer_str)
self._data.append((prompt_tokens, answer_tokens, prompt_str, answer_str))
def __getitem__(self, idx: int) -> Tuple[List[int], List[int], str, str]:
"""Returns a (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple."""
self._data.append(GRPOExample(
prompt_tokens=prompt_tokens,
answer_tokens=answer_tokens,
prompt_text=prompt_str,
answer_text=answer_str
))
def __getitem__(self, idx: int) -> GRPOExample:
"""Returns a GRPOExample instance."""
return self._data[idx]
def __len__(self) -> int:
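For context, here is a minimal usage sketch of the updated dataset wrapper. It is not part of this diff: the keyword names for the data/tokenizer arguments and the tokenizer object itself are assumptions.

# Hypothetical usage sketch (not from this commit); tokenizer is assumed to be any
# PreTrainedTokenizer-compatible object providing apply_chat_template/encode.
data = [{"prompt": "What is 2 + 2?", "answer": "4"}]
dataset = GRPODataset(
    data=data,
    tokenizer=tokenizer,
    prompt_key="prompt",
    answer_key="answer",
    use_chat_template=True,
)
example = dataset[0]   # a GRPOExample instead of the old 4-tuple
print(example.prompt_text, len(example.prompt_tokens), len(example.answer_tokens))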

View File

@@ -2,6 +2,7 @@
import time
from dataclasses import dataclass, field
from typing import List, Iterator, Optional
from pathlib import Path
import re
@@ -10,6 +11,7 @@ import mlx.nn as nn
import numpy as np
from mlx.utils import tree_flatten
from .utils import GRPOBatch, GRPOExample
from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients, iterate_batches
@dataclass
@@ -109,7 +111,7 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> li
return scores
def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
def generate_grpo(model: nn.Module, prompt, max_tokens, tokenizer, temperature):
if len(prompt.shape) == 1:
prompt = prompt[None, :]
if prompt.shape[1] == 0:
@@ -117,9 +119,11 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
end_sequence = tokenizer.encode("</answer>")
end_sequence_length = len(end_sequence)
output = mx.zeros((prompt.shape[1] + max_tokens,), dtype=mx.int32)
output[:prompt.shape[1]] = prompt[0]
current_length = prompt.shape[1]
initial_length = prompt.shape[1]
output = mx.zeros((initial_length + max_tokens,), dtype=mx.int32)
output[:initial_length] = prompt[0]
current_length = initial_length
try:
def sample(logits):
@@ -145,7 +149,7 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
if last_tokens == end_sequence:
break
if current_length > prompt.shape[1]:
if current_length > initial_length:
return output[:current_length]
except Exception as e:
@@ -155,7 +159,7 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
return None
def get_per_token_logps(model, inputs, lengths):
def get_per_token_logps(model: nn.Module, inputs, lengths):
logits = model(inputs).astype(mx.float16)
logits = logits[:, :-1, :]
targets = inputs[:, 1:]
@@ -176,10 +180,10 @@ def get_per_token_logps(model, inputs, lengths):
def grpo_loss(
model,
ref_model,
model: nn.Module,
ref_model: Optional[nn.Module],
tokenizer,
batch,
batch: GRPOBatch,
reward_funcs=None,
beta=0.1,
group_size=4,
@@ -187,15 +191,18 @@
max_tokens=64,
temperature=1.0
):
prompt_tokens, answer_tokens, prompt_text, answer_text = batch
batch_size = len(prompt_tokens)
prompts_tokens = batch.prompt_tokens
answers_tokens = batch.answer_tokens
prompts_text = batch.prompt_texts
answers_text = batch.answer_texts
batch_size = len(prompts_tokens)
# Generation logic remains the same
all_completions = []
all_completion_texts = []
for i in range(0, batch_size, batch_size):
batch_prompts = prompt_tokens[i:i+batch_size]
batch_prompts = prompts_tokens[i:i+batch_size]
for prompt in batch_prompts:
prompt_tensor = mx.array(prompt)
for _ in range(group_size):
@@ -219,8 +226,8 @@
expanded_answers = []
expanded_prompts = []
for i in range(batch_size):
expanded_answers.extend([answer_text[i]] * group_size)
expanded_prompts.extend([prompt_text[i]] * group_size)
expanded_answers.extend([answers_text[i]] * group_size)
expanded_prompts.extend([prompts_text[i]] * group_size)
max_length = max(ids.shape[0] for ids in all_completions)
padded_completions = []
@@ -332,60 +339,66 @@ def grpo_loss(
return loss, sequence_lengths.sum(), metrics
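For illustration, a hedged sketch of calling the updated loss with a GRPOBatch; model, ref_model, tokenizer, and batch are placeholders (batch would come from iterate_grpo_batches below), and the reward function list mirrors the defaults used in train_grpo.

# Illustrative call only, not from this commit.
loss, total_tokens, metrics = grpo_loss(
    model=model,                      # policy model (nn.Module)
    ref_model=ref_model,              # Optional[nn.Module]
    tokenizer=tokenizer,
    batch=batch,                      # GRPOBatch with parallel token/text lists
    reward_funcs=[r1_accuracy_reward_func, r1_int_reward_func, r1_count_xml],
    beta=0.1,
    group_size=4,
    max_tokens=64,
    temperature=1.0,
)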
def iterate_grpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
if not dataset or not isinstance(dataset[0], tuple) or len(dataset[0]) != 4:
raise ValueError("Dataset must be list of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples")
# Sort by length but use generator to avoid keeping full sorted list in memory
def length_key(i):
return len(dataset[i][0]) + len(dataset[i][1])
idx = sorted(range(len(dataset)), key=length_key)
def iterate_grpo_batches(
dataset: List[GRPOExample],
batch_size: int,
max_seq_length: int,
train: bool = False,
) -> Iterator[GRPOBatch]:
if len(dataset) < batch_size:
raise ValueError(
f"Dataset must have at least batch_size={batch_size} "
f"examples but only has {len(dataset)}."
)
# Get MLX distributed setup
step = mx.distributed.init().size()
if batch_size % step != 0:
raise ValueError("The batch size must be divisible by the number of workers")
# Use generator for batch indices
def batch_index_generator():
for i in range(0, len(idx) - batch_size + 1, batch_size):
yield idx[i : i + batch_size : step]
# Sort by combined length for efficient batching
def length_key(example: GRPOExample) -> int:
return len(example.prompt_tokens) + len(example.answer_tokens)
sorted_dataset = sorted(dataset, key=length_key)
# Create batch indices
num_complete_batches = len(sorted_dataset) // batch_size
batch_starts = range(0, num_complete_batches * batch_size, batch_size)
while True:
indices = (
np.random.permutation(list(batch_index_generator())) if train
else batch_index_generator()
)
# Shuffle batch start indices
shuffled_starts = np.random.permutation(batch_starts)
for batch_idx in indices:
current_batch = [dataset[j] for j in batch_idx]
for start_idx in shuffled_starts:
# Account for distributed setup by taking every step-th example
batch_idx = list(range(start_idx, start_idx + batch_size, step))
current_batch = [sorted_dataset[j] for j in batch_idx]
prompts_tokens = [item[0] for item in current_batch]
answers_tokens = [item[1] for item in current_batch]
prompts_text = [item[2] for item in current_batch]
answers_text = [item[3] for item in current_batch]
# Create batch using dataclass attributes
batch = GRPOBatch(
prompt_tokens=[ex.prompt_tokens for ex in current_batch],
answer_tokens=[ex.answer_tokens for ex in current_batch],
prompt_texts=[ex.prompt_text for ex in current_batch],
answer_texts=[ex.answer_text for ex in current_batch]
)
if any(len(p) > max_seq_length for p in prompts_tokens):
# Check sequence lengths
if any(len(tokens) > max_seq_length for tokens in batch.prompt_tokens):
print(
f"[WARNING] Some prompts are longer than {max_seq_length} tokens. "
"Long prompts will be truncated."
)
yield prompts_tokens, answers_tokens, prompts_text, answers_text
yield batch
if not train:
break
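A small driving sketch for the new iterator, assuming a single-process run (mx.distributed.init().size() == 1) and toy placeholder examples:

# Toy data; with train=True the iterator loops forever, so bound it with zip.
examples = [
    GRPOExample(prompt_tokens=[1, 2, 3], answer_tokens=[4, 5], prompt_text="p", answer_text="a")
    for _ in range(8)
]
batch_iter = iterate_grpo_batches(dataset=examples, batch_size=4, max_seq_length=512, train=True)
for _, batch in zip(range(2), batch_iter):
    assert isinstance(batch, GRPOBatch)
    assert len(batch.prompt_tokens) == 4  # one entry per example in the batch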
def evaluate_grpo(
model,
ref_model,
model: nn.Module,
ref_model: Optional[nn.Module],
dataset,
tokenizer,
batch_size,
@@ -415,7 +428,6 @@ def evaluate_grpo(
index_iterator,
iterate_batches(
dataset=dataset,
tokenizer=tokenizer,
batch_size=batch_size,
max_seq_length=max_seq_length,
),
@@ -459,12 +471,12 @@ def evaluate_grpo(
def train_grpo(
model,
ref_model,
model: nn.Module,
ref_model: Optional[nn.Module],
tokenizer,
optimizer,
train_dataset,
val_dataset,
train_dataset: List[GRPOExample],
val_dataset: List[GRPOExample],
reward_funcs = [
r1_accuracy_reward_func,
r1_int_reward_func,
@@ -535,7 +547,6 @@ def train_grpo(
range(1, args.iters + 1),
iterate_batches(
dataset=train_dataset,
tokenizer=tokenizer,
batch_size=args.batch_size,
max_seq_length=args.max_seq_length,
train=True,

View File

@@ -2,7 +2,8 @@
import json
import types
from pathlib import Path
from typing import Dict
from typing import Dict, List
from dataclasses import dataclass
import mlx.core as mx
import mlx.nn as nn
@@ -275,3 +276,19 @@ def print_trainable_parameters(model):
f"Trainable parameters: {(trainable_p * 100 / total_p):.3f}% "
f"({trainable_p:.3f}M/{total_p:.3f}M)"
)
@dataclass
class GRPOExample:
"""Single example for GRPO training/inference."""
prompt_tokens: List[int]
answer_tokens: List[int]
prompt_text: str
answer_text: str
@dataclass
class GRPOBatch:
"""A batch of GRPO examples."""
prompt_tokens: List[List[int]]
answer_tokens: List[List[int]]
prompt_texts: List[str]
answer_texts: List[str]
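For reference, iterate_grpo_batches builds a GRPOBatch by transposing the per-example fields into parallel lists; a minimal sketch with toy values (not from this commit):

examples = [
    GRPOExample(prompt_tokens=[1, 2], answer_tokens=[3], prompt_text="hi", answer_text="4"),
    GRPOExample(prompt_tokens=[4, 5, 6], answer_tokens=[7, 8], prompt_text="yo", answer_text="15"),
]
batch = GRPOBatch(
    prompt_tokens=[ex.prompt_tokens for ex in examples],
    answer_tokens=[ex.answer_tokens for ex in examples],
    prompt_texts=[ex.prompt_text for ex in examples],
    answer_texts=[ex.answer_text for ex in examples],
)
assert len(batch.prompt_tokens) == len(batch.answer_texts) == 2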