update new iterate batches function + nits

This commit is contained in:
Goekdeniz-Guelmez 2025-02-12 08:57:26 +01:00
parent e80bf95182
commit 5aeefc8c47
3 changed files with 100 additions and 64 deletions

View File

@@ -4,6 +4,7 @@ import types
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple
+from .utils import GRPOExample
 from transformers import PreTrainedTokenizer
@@ -11,7 +12,7 @@ class GRPODataset:
     """
     Dataset wrapper for GRPO training data.
     Each example should have a 'prompt' and 'answer' field.
-    Returns data in (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple format.
+    Returns data as GRPOExample instances.
     """
     def __init__(
         self,
@@ -22,10 +23,11 @@ class GRPODataset:
         use_chat_template: bool = False,
         use_prompt: bool = False
     ):
-        self._data = []
+        self._data: List[GRPOExample] = []
+
         for item in data:
             prompt_str = str(item[prompt_key])
             answer_str = str(item[answer_key])
             if use_chat_template:
                 prompt_tokens = tokenizer.apply_chat_template(
                     [
@@ -45,10 +47,16 @@ class GRPODataset:
             else:
                 prompt_tokens = tokenizer.encode(prompt_str)
                 answer_tokens = tokenizer.encode(answer_str)
-            self._data.append((prompt_tokens, answer_tokens, prompt_str, answer_str))
+            self._data.append(GRPOExample(
+                prompt_tokens=prompt_tokens,
+                answer_tokens=answer_tokens,
+                prompt_text=prompt_str,
+                answer_text=answer_str
+            ))
 
-    def __getitem__(self, idx: int) -> Tuple[List[int], List[int], str, str]:
-        """Returns a (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple."""
+    def __getitem__(self, idx: int) -> GRPOExample:
+        """Returns a GRPOExample instance."""
         return self._data[idx]
 
     def __len__(self) -> int:
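
With this change, indexing the dataset yields a GRPOExample instead of a 4-tuple. A minimal access sketch, assuming any Hugging Face tokenizer and toy data (both illustrative, not from this repo):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer for illustration
    data = [{"prompt": "What is 2+2?", "answer": "4"}]
    dataset = GRPODataset(data=data, tokenizer=tokenizer, prompt_key="prompt", answer_key="answer")
    example = dataset[0]                                # GRPOExample, no longer a tuple
    print(example.prompt_text, example.answer_text)
    print(len(example.prompt_tokens), len(example.answer_tokens))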

View File

@@ -2,6 +2,7 @@
 import time
 from dataclasses import dataclass, field
+from typing import List, Iterator, Optional
 from pathlib import Path
 import re
@@ -10,6 +11,7 @@ import mlx.nn as nn
 import numpy as np
 from mlx.utils import tree_flatten
+from .utils import GRPOBatch, GRPOExample
 from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients, iterate_batches
 
 @dataclass
@@ -109,7 +111,7 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list:
     return scores
 
-def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
+def generate_grpo(model: nn.Module, prompt, max_tokens, tokenizer, temperature):
     if len(prompt.shape) == 1:
         prompt = prompt[None, :]
     if prompt.shape[1] == 0:
@@ -117,9 +119,11 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
     end_sequence = tokenizer.encode("</answer>")
     end_sequence_length = len(end_sequence)
-    output = mx.zeros((prompt.shape[1] + max_tokens,), dtype=mx.int32)
-    output[:prompt.shape[1]] = prompt[0]
-    current_length = prompt.shape[1]
+
+    initial_length = prompt.shape[1]
+    output = mx.zeros((initial_length + max_tokens,), dtype=mx.int32)
+    output[:initial_length] = prompt[0]
+    current_length = initial_length
 
     try:
         def sample(logits):
@@ -145,7 +149,7 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
             if last_tokens == end_sequence:
                 break
 
-        if current_length > prompt.shape[1]:
+        if current_length > initial_length:
             return output[:current_length]
 
     except Exception as e:
@@ -155,7 +159,7 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
         return None
 
-def get_per_token_logps(model, inputs, lengths):
+def get_per_token_logps(model: nn.Module, inputs, lengths):
     logits = model(inputs).astype(mx.float16)
     logits = logits[:, :-1, :]
     targets = inputs[:, 1:]
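
For reference, get_per_token_logps relies on the usual next-token shift: logits at position t score the token at position t+1, so the targets are the inputs shifted left by one. A small self-contained sketch of that gather (a sketch of the idea only; the rest of the function body is not shown in this hunk):

    import mlx.core as mx

    def per_token_logps_sketch(logits: mx.array, inputs: mx.array) -> mx.array:
        # logits: (batch, seq, vocab), inputs: (batch, seq) token ids
        logits = logits[:, :-1, :]                                     # drop final position
        targets = inputs[:, 1:]                                        # next-token targets
        logps = logits - mx.logsumexp(logits, axis=-1, keepdims=True)  # log-softmax
        gathered = mx.take_along_axis(logps, mx.expand_dims(targets, -1), axis=-1)
        return mx.squeeze(gathered, axis=-1)                           # (batch, seq - 1)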
@@ -176,10 +180,10 @@ def get_per_token_logps(model, inputs, lengths):
 
 def grpo_loss(
-    model,
-    ref_model,
+    model: nn.Module,
+    ref_model: Optional[nn.Module],
     tokenizer,
-    batch,
+    batch: GRPOBatch,
     reward_funcs=None,
     beta=0.1,
     group_size=4,
@@ -187,15 +191,18 @@ def grpo_loss(
     max_tokens=64,
     temperature=1.0
 ):
-    prompt_tokens, answer_tokens, prompt_text, answer_text = batch
-    batch_size = len(prompt_tokens)
+    prompts_tokens = batch.prompt_tokens
+    answers_tokens = batch.answer_tokens
+    prompts_text = batch.prompt_texts
+    answers_text = batch.answer_texts
+    batch_size = len(prompts_tokens)
 
     # Generation logic remains the same
     all_completions = []
     all_completion_texts = []
 
     for i in range(0, batch_size, batch_size):
-        batch_prompts = prompt_tokens[i:i+batch_size]
+        batch_prompts = prompts_tokens[i:i+batch_size]
         for prompt in batch_prompts:
             prompt_tensor = mx.array(prompt)
             for _ in range(group_size):
@@ -219,8 +226,8 @@ def grpo_loss(
     expanded_answers = []
     expanded_prompts = []
     for i in range(batch_size):
-        expanded_answers.extend([answer_text[i]] * group_size)
-        expanded_prompts.extend([prompt_text[i]] * group_size)
+        expanded_answers.extend([answers_text[i]] * group_size)
+        expanded_prompts.extend([prompts_text[i]] * group_size)
 
     max_length = max(ids.shape[0] for ids in all_completions)
     padded_completions = []
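
The expansion above keeps rewards aligned with completions: each prompt and answer string is repeated group_size times, so index j of the expanded lists corresponds to completion j. A toy illustration with made-up values:

    prompts_text = ["p0", "p1"]   # batch_size = 2
    group_size = 3                # completions sampled per prompt
    expanded_prompts = []
    for i in range(len(prompts_text)):
        expanded_prompts.extend([prompts_text[i]] * group_size)
    # expanded_prompts == ["p0", "p0", "p0", "p1", "p1", "p1"], matching 6 completions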
@@ -332,60 +339,66 @@ def grpo_loss(
     return loss, sequence_lengths.sum(), metrics
 
-def iterate_grpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
-    if not dataset or not isinstance(dataset[0], tuple) or len(dataset[0]) != 4:
-        raise ValueError("Dataset must be list of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples")
-
-    # Sort by length but use generator to avoid keeping full sorted list in memory
-    def length_key(i):
-        return len(dataset[i][0]) + len(dataset[i][1])
-    idx = sorted(range(len(dataset)), key=length_key)
-
+def iterate_grpo_batches(
+    dataset: List[GRPOExample],
+    batch_size: int,
+    max_seq_length: int,
+    train: bool = False,
+) -> Iterator[GRPOBatch]:
     if len(dataset) < batch_size:
         raise ValueError(
             f"Dataset must have at least batch_size={batch_size} "
             f"examples but only has {len(dataset)}."
         )
 
-    # Get MLX distributed setup
     step = mx.distributed.init().size()
     if batch_size % step != 0:
         raise ValueError("The batch size must be divisible by the number of workers")
 
-    # Use generator for batch indices
-    def batch_index_generator():
-        for i in range(0, len(idx) - batch_size + 1, batch_size):
-            yield idx[i : i + batch_size : step]
+    # Sort by combined length for efficient batching
+    def length_key(example: GRPOExample) -> int:
+        return len(example.prompt_tokens) + len(example.answer_tokens)
+
+    sorted_dataset = sorted(dataset, key=length_key)
+
+    # Create batch indices
+    num_complete_batches = (len(dataset) - batch_size + 1) // batch_size
+    batch_starts = range(0, num_complete_batches * batch_size, batch_size)
 
     while True:
-        indices = (
-            np.random.permutation(list(batch_index_generator())) if train
-            else batch_index_generator()
-        )
-        for batch_idx in indices:
-            current_batch = [dataset[j] for j in batch_idx]
-
-            prompts_tokens = [item[0] for item in current_batch]
-            answers_tokens = [item[1] for item in current_batch]
-            prompts_text = [item[2] for item in current_batch]
-            answers_text = [item[3] for item in current_batch]
-
-            if any(len(p) > max_seq_length for p in prompts_tokens):
+        # Shuffle batch start indices
+        shuffled_starts = np.random.permutation(batch_starts)
+
+        for start_idx in shuffled_starts:
+            # Account for distributed setup by taking every step-th example
+            batch_idx = list(range(start_idx, start_idx + batch_size, step))
+            current_batch = [sorted_dataset[j] for j in batch_idx]
+
+            # Create batch using dataclass attributes
+            batch = GRPOBatch(
+                prompt_tokens=[ex.prompt_tokens for ex in current_batch],
+                answer_tokens=[ex.answer_tokens for ex in current_batch],
+                prompt_texts=[ex.prompt_text for ex in current_batch],
+                answer_texts=[ex.answer_text for ex in current_batch]
+            )
+
+            # Check sequence lengths
+            if any(len(tokens) > max_seq_length for tokens in batch.prompt_tokens):
                 print(
                     f"[WARNING] Some prompts are longer than {max_seq_length} tokens. "
                     "Long prompts will be truncated."
                 )
 
-            yield prompts_tokens, answers_tokens, prompts_text, answers_text
+            yield batch
 
         if not train:
             break
 
 def evaluate_grpo(
-    model,
-    ref_model,
+    model: nn.Module,
+    ref_model: Optional[nn.Module],
     dataset,
     tokenizer,
     batch_size,
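
A usage sketch for the reworked iterate_grpo_batches above, with made-up token ids (assumes a single worker, so mx.distributed.init().size() is 1):

    examples = [
        GRPOExample(
            prompt_tokens=list(range(3 + i)),
            answer_tokens=[1, 2],
            prompt_text=f"prompt {i}",
            answer_text=f"answer {i}",
        )
        for i in range(8)
    ]
    for batch in iterate_grpo_batches(examples, batch_size=4, max_seq_length=512):
        # batch is a GRPOBatch whose fields are parallel lists of length batch_size
        print(len(batch.prompt_tokens), batch.prompt_texts[0])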
@@ -415,7 +428,6 @@ def evaluate_grpo(
         index_iterator,
         iterate_batches(
             dataset=dataset,
-            tokenizer=tokenizer,
             batch_size=batch_size,
             max_seq_length=max_seq_length,
         ),
@@ -459,12 +471,12 @@ def evaluate_grpo(
 
 def train_grpo(
-    model,
-    ref_model,
+    model: nn.Module,
+    ref_model: Optional[nn.Module],
     tokenizer,
     optimizer,
-    train_dataset,
-    val_dataset,
+    train_dataset: List[GRPOExample],
+    val_dataset: List[GRPOExample],
     reward_funcs = [
         r1_accuracy_reward_func,
         r1_int_reward_func,
@@ -535,7 +547,6 @@ def train_grpo(
         range(1, args.iters + 1),
         iterate_batches(
             dataset=train_dataset,
-            tokenizer=tokenizer,
             batch_size=args.batch_size,
             max_seq_length=args.max_seq_length,
             train=True,

View File

@@ -2,7 +2,8 @@
 import json
 import types
 from pathlib import Path
-from typing import Dict
+from typing import Dict, List
+from dataclasses import dataclass
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -275,3 +276,19 @@ def print_trainable_parameters(model):
         f"Trainable parameters: {(trainable_p * 100 / total_p):.3f}% "
         f"({trainable_p:.3f}M/{total_p:.3f}M)"
     )
+
+@dataclass
+class GRPOExample:
+    """Single example for GRPO training/inference."""
+    prompt_tokens: List[int]
+    answer_tokens: List[int]
+    prompt_text: str
+    answer_text: str
+
+@dataclass
+class GRPOBatch:
+    """A batch of GRPO examples."""
+    prompt_tokens: List[List[int]]
+    answer_tokens: List[List[int]]
+    prompt_texts: List[str]
+    answer_texts: List[str]
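
GRPOBatch is the column-wise view of a list of GRPOExample rows; a hypothetical collate helper makes the mapping explicit (iterate_grpo_batches builds batches the same way):

    def collate(examples: List[GRPOExample]) -> GRPOBatch:
        # transpose row objects into the parallel column lists grpo_loss reads
        return GRPOBatch(
            prompt_tokens=[e.prompt_tokens for e in examples],
            answer_tokens=[e.answer_tokens for e in examples],
            prompt_texts=[e.prompt_text for e in examples],
            answer_texts=[e.answer_text for e in examples],
        )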