update new iterade batches function + nits

This commit is contained in:
Goekdeniz-Guelmez 2025-02-12 08:57:26 +01:00
parent e80bf95182
commit 5aeefc8c47
3 changed files with 100 additions and 64 deletions

View File

@ -4,6 +4,7 @@ import types
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from .utils import GRPOExample
from transformers import PreTrainedTokenizer
@ -11,7 +12,7 @@ class GRPODataset:
"""
Dataset wrapper for GRPO training data.
Each example should have a 'prompt' and 'answer' field.
Returns data in (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple format.
Returns data as GRPOExample instances.
"""
def __init__(
self,
@ -22,10 +23,11 @@ class GRPODataset:
use_chat_template: bool = False,
use_prompt: bool = False
):
self._data = []
self._data: List[GRPOExample] = []
for item in data:
prompt_str = str(item[prompt_key])
answer_str = str(item[answer_key])
if use_chat_template:
prompt_tokens = tokenizer.apply_chat_template(
[
@ -45,10 +47,16 @@ class GRPODataset:
else:
prompt_tokens = tokenizer.encode(prompt_str)
answer_tokens = tokenizer.encode(answer_str)
self._data.append((prompt_tokens, answer_tokens, prompt_str, answer_str))
def __getitem__(self, idx: int) -> Tuple[List[int], List[int], str, str]:
"""Returns a (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple."""
self._data.append(GRPOExample(
prompt_tokens=prompt_tokens,
answer_tokens=answer_tokens,
prompt_text=prompt_str,
answer_text=answer_str
))
def __getitem__(self, idx: int) -> GRPOExample:
"""Returns a GRPOExample instance."""
return self._data[idx]
def __len__(self) -> int:

View File

@ -2,6 +2,7 @@
import time
from dataclasses import dataclass, field
from typing import List, Iterator, Optional
from pathlib import Path
import re
@ -10,6 +11,7 @@ import mlx.nn as nn
import numpy as np
from mlx.utils import tree_flatten
from .utils import GRPOBatch, GRPOExample
from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients, iterate_batches
@dataclass
@ -109,7 +111,7 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> li
return scores
def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
def generate_grpo(model: nn.Module, prompt, max_tokens, tokenizer, temperature):
if len(prompt.shape) == 1:
prompt = prompt[None, :]
if prompt.shape[1] == 0:
@ -117,9 +119,11 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
end_sequence = tokenizer.encode("</answer>")
end_sequence_length = len(end_sequence)
output = mx.zeros((prompt.shape[1] + max_tokens,), dtype=mx.int32)
output[:prompt.shape[1]] = prompt[0]
current_length = prompt.shape[1]
initial_length = prompt.shape[1]
output = mx.zeros((initial_length + max_tokens,), dtype=mx.int32)
output[:initial_length] = prompt[0]
current_length = initial_length
try:
def sample(logits):
@ -145,7 +149,7 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
if last_tokens == end_sequence:
break
if current_length > prompt.shape[1]:
if current_length > initial_length:
return output[:current_length]
except Exception as e:
@ -155,7 +159,7 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
return None
def get_per_token_logps(model, inputs, lengths):
def get_per_token_logps(model: nn.Module, inputs, lengths):
logits = model(inputs).astype(mx.float16)
logits = logits[:, :-1, :]
targets = inputs[:, 1:]
@ -176,10 +180,10 @@ def get_per_token_logps(model, inputs, lengths):
def grpo_loss(
model,
ref_model,
model: nn.Module,
ref_model: Optional[nn.Module],
tokenizer,
batch,
batch=GRPOBatch,
reward_funcs=None,
beta=0.1,
group_size=4,
@ -187,15 +191,18 @@ def grpo_loss(
max_tokens=64,
temperature=1.0
):
prompt_tokens, answer_tokens, prompt_text, answer_text = batch
batch_size = len(prompt_tokens)
prompts_tokens = batch.prompt_tokens
answers_tokens = batch.answer_tokens
prompts_text = batch.prompt_texts
answers_text = batch.answer_texts
batch_size = len(prompts_tokens)
# Generation logic remains the same
all_completions = []
all_completion_texts = []
for i in range(0, batch_size, batch_size):
batch_prompts = prompt_tokens[i:i+batch_size]
batch_prompts = prompts_tokens[i:i+batch_size]
for prompt in batch_prompts:
prompt_tensor = mx.array(prompt)
for _ in range(group_size):
@ -219,8 +226,8 @@ def grpo_loss(
expanded_answers = []
expanded_prompts = []
for i in range(batch_size):
expanded_answers.extend([answer_text[i]] * group_size)
expanded_prompts.extend([prompt_text[i]] * group_size)
expanded_answers.extend([answers_text[i]] * group_size)
expanded_prompts.extend([prompts_text[i]] * group_size)
max_length = max(ids.shape[0] for ids in all_completions)
padded_completions = []
@ -332,60 +339,66 @@ def grpo_loss(
return loss, sequence_lengths.sum(), metrics
def iterate_grpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
if not dataset or not isinstance(dataset[0], tuple) or len(dataset[0]) != 4:
raise ValueError("Dataset must be list of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples")
# Sort by length but use generator to avoid keeping full sorted list in memory
def length_key(i):
return len(dataset[i][0]) + len(dataset[i][1])
idx = sorted(range(len(dataset)), key=length_key)
def iterate_grpo_batches(
dataset: List[GRPOExample],
batch_size: int,
max_seq_length: int,
train: bool = False,
) -> Iterator[GRPOBatch]:
if len(dataset) < batch_size:
raise ValueError(
f"Dataset must have at least batch_size={batch_size} "
f"examples but only has {len(dataset)}."
)
# Get MLX distributed setup
step = mx.distributed.init().size()
if batch_size % step != 0:
raise ValueError("The batch size must be divisible by the number of workers")
# Use generator for batch indices
def batch_index_generator():
for i in range(0, len(idx) - batch_size + 1, batch_size):
yield idx[i : i + batch_size : step]
# Sort by combined length for efficient batching
def length_key(example: GRPOExample) -> int:
return len(example.prompt_tokens) + len(example.answer_tokens)
sorted_dataset = sorted(dataset, key=length_key)
# Create batch indices
num_complete_batches = (len(dataset) - batch_size + 1) // batch_size
batch_starts = range(0, num_complete_batches * batch_size, batch_size)
while True:
indices = (
np.random.permutation(list(batch_index_generator())) if train
else batch_index_generator()
# Shuffle batch start indices
shuffled_starts = np.random.permutation(batch_starts)
for start_idx in shuffled_starts:
# Account for distributed setup by taking every step-th example
batch_idx = list(range(start_idx, start_idx + batch_size, step))
current_batch = [sorted_dataset[j] for j in batch_idx]
# Create batch using dataclass attributes
batch = GRPOBatch(
prompt_tokens=[ex.prompt_tokens for ex in current_batch],
answer_tokens=[ex.answer_tokens for ex in current_batch],
prompt_texts=[ex.prompt_text for ex in current_batch],
answer_texts=[ex.answer_text for ex in current_batch]
)
for batch_idx in indices:
current_batch = [dataset[j] for j in batch_idx]
prompts_tokens = [item[0] for item in current_batch]
answers_tokens = [item[1] for item in current_batch]
prompts_text = [item[2] for item in current_batch]
answers_text = [item[3] for item in current_batch]
if any(len(p) > max_seq_length for p in prompts_tokens):
# Check sequence lengths
if any(len(tokens) > max_seq_length for tokens in batch.prompt_tokens):
print(
f"[WARNING] Some prompts are longer than {max_seq_length} tokens. "
"Long prompts will be truncated."
)
yield prompts_tokens, answers_tokens, prompts_text, answers_text
yield batch
if not train:
break
def evaluate_grpo(
model,
ref_model,
model: nn.Module,
ref_model: Optional[nn.Module],
dataset,
tokenizer,
batch_size,
@ -415,7 +428,6 @@ def evaluate_grpo(
index_iterator,
iterate_batches(
dataset=dataset,
tokenizer=tokenizer,
batch_size=batch_size,
max_seq_length=max_seq_length,
),
@ -459,12 +471,12 @@ def evaluate_grpo(
def train_grpo(
model,
ref_model,
model: nn.Module,
ref_model: Optional[nn.Module],
tokenizer,
optimizer,
train_dataset,
val_dataset,
train_dataset: List[GRPOExample],
val_dataset: List[GRPOExample],
reward_funcs = [
r1_accuracy_reward_func,
r1_int_reward_func,
@ -535,7 +547,6 @@ def train_grpo(
range(1, args.iters + 1),
iterate_batches(
dataset=train_dataset,
tokenizer=tokenizer,
batch_size=args.batch_size,
max_seq_length=args.max_seq_length,
train=True,

View File

@ -2,7 +2,8 @@
import json
import types
from pathlib import Path
from typing import Dict
from typing import Dict, List
from dataclasses import dataclass
import mlx.core as mx
import mlx.nn as nn
@ -275,3 +276,19 @@ def print_trainable_parameters(model):
f"Trainable parameters: {(trainable_p * 100 / total_p):.3f}% "
f"({trainable_p:.3f}M/{total_p:.3f}M)"
)
@dataclass
class GRPOExample:
"""Single example for GRPO training/inference."""
prompt_tokens: List[int]
answer_tokens: List[int]
prompt_text: str
answer_text: str
@dataclass
class GRPOBatch:
"""A batch of GRPO examples."""
prompt_tokens: List[List[int]]
answer_tokens: List[List[int]]
prompt_texts: List[str]
answer_texts: List[str]