Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-26 02:33:23 +08:00
update new iterate batches function + nits
This commit is contained in:
parent e80bf95182
commit 5aeefc8c47
@@ -4,6 +4,7 @@ import types
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

+from .utils import GRPOExample
from transformers import PreTrainedTokenizer


@@ -11,7 +12,7 @@ class GRPODataset:
    """
    Dataset wrapper for GRPO training data.
    Each example should have a 'prompt' and 'answer' field.
-    Returns data in (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple format.
+    Returns data as GRPOExample instances.
    """
    def __init__(
        self,
@@ -22,10 +23,11 @@ class GRPODataset:
        use_chat_template: bool = False,
        use_prompt: bool = False
    ):
-        self._data = []
+        self._data: List[GRPOExample] = []
        for item in data:
            prompt_str = str(item[prompt_key])
            answer_str = str(item[answer_key])
+
            if use_chat_template:
                prompt_tokens = tokenizer.apply_chat_template(
                    [
@@ -45,10 +47,16 @@ class GRPODataset:
            else:
                prompt_tokens = tokenizer.encode(prompt_str)
                answer_tokens = tokenizer.encode(answer_str)
-            self._data.append((prompt_tokens, answer_tokens, prompt_str, answer_str))
-
-    def __getitem__(self, idx: int) -> Tuple[List[int], List[int], str, str]:
-        """Returns a (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple."""
+            self._data.append(GRPOExample(
+                prompt_tokens=prompt_tokens,
+                answer_tokens=answer_tokens,
+                prompt_text=prompt_str,
+                answer_text=answer_str
+            ))
+
+    def __getitem__(self, idx: int) -> GRPOExample:
+        """Returns a GRPOExample instance."""
        return self._data[idx]

    def __len__(self) -> int:
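A minimal sketch of what indexing the updated dataset yields. The GRPOExample definition is copied from the .utils hunk at the end of this diff; toy_encode is a hypothetical stand-in for tokenizer.encode, used only so the snippet runs on its own:

    from dataclasses import dataclass
    from typing import List

    @dataclass
    class GRPOExample:
        prompt_tokens: List[int]
        answer_tokens: List[int]
        prompt_text: str
        answer_text: str

    def toy_encode(text: str) -> List[int]:
        # Hypothetical stand-in for tokenizer.encode: one fake id per character.
        return [ord(c) for c in text]

    # Previously __getitem__ returned a 4-tuple unpacked by position; now callers
    # read named attributes, which is what the grpo_loss changes below rely on.
    example = GRPOExample(
        prompt_tokens=toy_encode("What is 2+2?"),
        answer_tokens=toy_encode("4"),
        prompt_text="What is 2+2?",
        answer_text="4",
    )
    print(example.prompt_text, len(example.prompt_tokens))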
@@ -2,6 +2,7 @@

import time
from dataclasses import dataclass, field
+from typing import List, Iterator, Optional
from pathlib import Path
import re

@@ -10,6 +11,7 @@ import mlx.nn as nn
import numpy as np
from mlx.utils import tree_flatten

+from .utils import GRPOBatch, GRPOExample
from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients, iterate_batches

@dataclass
@@ -109,7 +111,7 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list:
    return scores


-def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
+def generate_grpo(model: nn.Module, prompt, max_tokens, tokenizer, temperature):
    if len(prompt.shape) == 1:
        prompt = prompt[None, :]
    if prompt.shape[1] == 0:
@@ -117,9 +119,11 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):

    end_sequence = tokenizer.encode("</answer>")
    end_sequence_length = len(end_sequence)
-    output = mx.zeros((prompt.shape[1] + max_tokens,), dtype=mx.int32)
-    output[:prompt.shape[1]] = prompt[0]
-    current_length = prompt.shape[1]
+
+    initial_length = prompt.shape[1]
+    output = mx.zeros((initial_length + max_tokens,), dtype=mx.int32)
+    output[:initial_length] = prompt[0]
+    current_length = initial_length

    try:
        def sample(logits):
@@ -145,7 +149,7 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
            if last_tokens == end_sequence:
                break

-        if current_length > prompt.shape[1]:
+        if current_length > initial_length:
            return output[:current_length]

    except Exception as e:
@@ -155,7 +159,7 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
        return None


-def get_per_token_logps(model, inputs, lengths):
+def get_per_token_logps(model: nn.Module, inputs, lengths):
    logits = model(inputs).astype(mx.float16)
    logits = logits[:, :-1, :]
    targets = inputs[:, 1:]
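The slicing in get_per_token_logps pairs the logits at position t with the token at position t+1 (drop the last logit row, drop the first target token). A small NumPy sketch of that alignment; shapes and values here are toy illustrations, not the MLX implementation:

    import numpy as np

    # One toy sequence of 5 token ids over a vocabulary of 10.
    inputs = np.array([[5, 7, 2, 9, 1]])
    logits = np.random.default_rng(0).normal(size=(1, 5, 10))

    # Position t predicts token t+1: drop the last logit row and the first target.
    shifted_logits = logits[:, :-1, :]   # (1, 4, 10)
    targets = inputs[:, 1:]              # (1, 4)

    # Log-softmax over the vocab, then gather the log-prob of each target token.
    log_probs = shifted_logits - np.log(np.exp(shifted_logits).sum(axis=-1, keepdims=True))
    per_token_logps = np.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    print(per_token_logps.shape)  # (1, 4): one log-probability per predicted token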
@@ -176,10 +180,10 @@ def get_per_token_logps(model, inputs, lengths):


def grpo_loss(
-    model,
-    ref_model,
+    model: nn.Module,
+    ref_model: Optional[nn.Module],
    tokenizer,
-    batch,
+    batch=GRPOBatch,
    reward_funcs=None,
    beta=0.1,
    group_size=4,
@@ -187,15 +191,18 @@ def grpo_loss(
    max_tokens=64,
    temperature=1.0
):
-    prompt_tokens, answer_tokens, prompt_text, answer_text = batch
-    batch_size = len(prompt_tokens)
+    prompts_tokens = batch.prompt_tokens
+    answers_tokens = batch.answer_tokens
+    prompts_text = batch.prompt_texts
+    answers_text = batch.answer_texts
+    batch_size = len(prompts_tokens)

    # Generation logic remains the same
    all_completions = []
    all_completion_texts = []

    for i in range(0, batch_size, batch_size):
-        batch_prompts = prompt_tokens[i:i+batch_size]
+        batch_prompts = prompts_tokens[i:i+batch_size]
        for prompt in batch_prompts:
            prompt_tensor = mx.array(prompt)
            for _ in range(group_size):
@@ -219,8 +226,8 @@ def grpo_loss(
    expanded_answers = []
    expanded_prompts = []
    for i in range(batch_size):
-        expanded_answers.extend([answer_text[i]] * group_size)
-        expanded_prompts.extend([prompt_text[i]] * group_size)
+        expanded_answers.extend([answers_text[i]] * group_size)
+        expanded_prompts.extend([prompts_text[i]] * group_size)

    max_length = max(ids.shape[0] for ids in all_completions)
    padded_completions = []
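The edit above is only a rename (the singular answer_text/prompt_text become the plural batch fields); the expansion itself just repeats each prompt and reference answer group_size times so rewards can be scored against every sampled completion. A toy illustration with hypothetical values:

    group_size = 2
    prompts_text = ["What is 2+2?", "What is 3*3?"]
    answers_text = ["4", "9"]

    expanded_prompts, expanded_answers = [], []
    for i in range(len(prompts_text)):
        # Each prompt gets group_size sampled completions, so its text and
        # reference answer are repeated once per completion.
        expanded_prompts.extend([prompts_text[i]] * group_size)
        expanded_answers.extend([answers_text[i]] * group_size)

    assert expanded_answers == ["4", "4", "9", "9"]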
@@ -332,60 +339,66 @@ def grpo_loss(
    return loss, sequence_lengths.sum(), metrics


-def iterate_grpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
-    if not dataset or not isinstance(dataset[0], tuple) or len(dataset[0]) != 4:
-        raise ValueError("Dataset must be list of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples")
-
-    # Sort by length but use generator to avoid keeping full sorted list in memory
-    def length_key(i):
-        return len(dataset[i][0]) + len(dataset[i][1])
-
-    idx = sorted(range(len(dataset)), key=length_key)
-
+def iterate_grpo_batches(
+    dataset: List[GRPOExample],
+    batch_size: int,
+    max_seq_length: int,
+    train: bool = False,
+) -> Iterator[GRPOBatch]:
+    if len(dataset) < batch_size:
+        raise ValueError(
+            f"Dataset must have at least batch_size={batch_size} "
+            f"examples but only has {len(dataset)}."
+        )
+
    # Get MLX distributed setup
    step = mx.distributed.init().size()
    if batch_size % step != 0:
        raise ValueError("The batch size must be divisible by the number of workers")

-    # Use generator for batch indices
-    def batch_index_generator():
-        for i in range(0, len(idx) - batch_size + 1, batch_size):
-            yield idx[i : i + batch_size : step]
+    # Sort by combined length for efficient batching
+    def length_key(example: GRPOExample) -> int:
+        return len(example.prompt_tokens) + len(example.answer_tokens)
+
+    sorted_dataset = sorted(dataset, key=length_key)
+
+    # Create batch indices
+    num_complete_batches = (len(dataset) - batch_size + 1) // batch_size
+    batch_starts = range(0, num_complete_batches * batch_size, batch_size)

    while True:
-        indices = (
-            np.random.permutation(list(batch_index_generator())) if train
-            else batch_index_generator()
+        # Shuffle batch start indices
+        shuffled_starts = np.random.permutation(batch_starts)
+
+        for start_idx in shuffled_starts:
+            # Account for distributed setup by taking every step-th example
+            batch_idx = list(range(start_idx, start_idx + batch_size, step))
+            current_batch = [sorted_dataset[j] for j in batch_idx]
+
+            # Create batch using dataclass attributes
+            batch = GRPOBatch(
+                prompt_tokens=[ex.prompt_tokens for ex in current_batch],
+                answer_tokens=[ex.answer_tokens for ex in current_batch],
+                prompt_texts=[ex.prompt_text for ex in current_batch],
+                answer_texts=[ex.answer_text for ex in current_batch]
+            )

-        for batch_idx in indices:
-            current_batch = [dataset[j] for j in batch_idx]
-
-            prompts_tokens = [item[0] for item in current_batch]
-            answers_tokens = [item[1] for item in current_batch]
-            prompts_text = [item[2] for item in current_batch]
-            answers_text = [item[3] for item in current_batch]
-
-            if any(len(p) > max_seq_length for p in prompts_tokens):
+            # Check sequence lengths
+            if any(len(tokens) > max_seq_length for tokens in batch.prompt_tokens):
                print(
                    f"[WARNING] Some prompts are longer than {max_seq_length} tokens. "
                    "Long prompts will be truncated."
                )

-            yield prompts_tokens, answers_tokens, prompts_text, answers_text
+            yield batch

        if not train:
            break


def evaluate_grpo(
-    model,
-    ref_model,
+    model: nn.Module,
+    ref_model: Optional[nn.Module],
    dataset,
    tokenizer,
    batch_size,
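A rough, self-contained sketch of what the new iterate_grpo_batches does per iteration: sort examples by combined length, walk shuffled batch starts, and yield GRPOBatch objects. The dataclasses are copied from .utils; the distributed step is fixed at 1 here (a single process), and iterate_batches_sketch is a simplified stand-in rather than the committed function:

    import numpy as np
    from dataclasses import dataclass
    from typing import Iterator, List

    @dataclass
    class GRPOExample:
        prompt_tokens: List[int]
        answer_tokens: List[int]
        prompt_text: str
        answer_text: str

    @dataclass
    class GRPOBatch:
        prompt_tokens: List[List[int]]
        answer_tokens: List[List[int]]
        prompt_texts: List[str]
        answer_texts: List[str]

    def iterate_batches_sketch(dataset, batch_size, step=1) -> Iterator[GRPOBatch]:
        # Sort by combined prompt+answer length so each batch holds similar lengths.
        data = sorted(dataset, key=lambda ex: len(ex.prompt_tokens) + len(ex.answer_tokens))
        starts = range(0, (len(data) // batch_size) * batch_size, batch_size)
        for start in np.random.permutation(list(starts)):
            # With several workers, each one would keep every step-th example.
            chunk = [data[j] for j in range(start, start + batch_size, step)]
            yield GRPOBatch(
                prompt_tokens=[ex.prompt_tokens for ex in chunk],
                answer_tokens=[ex.answer_tokens for ex in chunk],
                prompt_texts=[ex.prompt_text for ex in chunk],
                answer_texts=[ex.answer_text for ex in chunk],
            )

    examples = [GRPOExample([1] * n, [2] * n, "p" * n, "a" * n) for n in range(1, 9)]
    for b in iterate_batches_sketch(examples, batch_size=4):
        print(len(b.prompt_texts))  # 4 examples per batch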
@@ -415,7 +428,6 @@ def evaluate_grpo(
        index_iterator,
        iterate_batches(
            dataset=dataset,
-            tokenizer=tokenizer,
            batch_size=batch_size,
            max_seq_length=max_seq_length,
        ),
@@ -459,12 +471,12 @@ def evaluate_grpo(


def train_grpo(
-    model,
-    ref_model,
+    model: nn.Module,
+    ref_model: Optional[nn.Module],
    tokenizer,
    optimizer,
-    train_dataset,
-    val_dataset,
+    train_dataset: List[GRPOExample],
+    val_dataset: List[GRPOExample],
    reward_funcs = [
        r1_accuracy_reward_func,
        r1_int_reward_func,
@@ -535,7 +547,6 @@ def train_grpo(
        range(1, args.iters + 1),
        iterate_batches(
            dataset=train_dataset,
-            tokenizer=tokenizer,
            batch_size=args.batch_size,
            max_seq_length=args.max_seq_length,
            train=True,
@@ -2,7 +2,8 @@
import json
import types
from pathlib import Path
-from typing import Dict
+from typing import Dict, List
+from dataclasses import dataclass

import mlx.core as mx
import mlx.nn as nn
@@ -275,3 +276,19 @@ def print_trainable_parameters(model):
        f"Trainable parameters: {(trainable_p * 100 / total_p):.3f}% "
        f"({trainable_p:.3f}M/{total_p:.3f}M)"
    )
+
+@dataclass
+class GRPOExample:
+    """Single example for GRPO training/inference."""
+    prompt_tokens: List[int]
+    answer_tokens: List[int]
+    prompt_text: str
+    answer_text: str
+
+@dataclass
+class GRPOBatch:
+    """A batch of GRPO examples."""
+    prompt_tokens: List[List[int]]
+    answer_tokens: List[List[int]]
+    prompt_texts: List[str]
+    answer_texts: List[str]
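The two dataclasses mirror each other: GRPOBatch carries the same four fields as GRPOExample, but as parallel lists (note the singular prompt_text/answer_text versus the plural prompt_texts/answer_texts). A small sketch of how the fields line up; the examples() helper is hypothetical, added here only for illustration and not part of the commit:

    from dataclasses import dataclass
    from typing import Iterator, List

    @dataclass
    class GRPOExample:
        prompt_tokens: List[int]
        answer_tokens: List[int]
        prompt_text: str
        answer_text: str

    @dataclass
    class GRPOBatch:
        prompt_tokens: List[List[int]]
        answer_tokens: List[List[int]]
        prompt_texts: List[str]
        answer_texts: List[str]

        def examples(self) -> Iterator[GRPOExample]:
            # Hypothetical helper: zip the parallel lists back into per-example records.
            for toks, ans, p_txt, a_txt in zip(
                self.prompt_tokens, self.answer_tokens, self.prompt_texts, self.answer_texts
            ):
                yield GRPOExample(toks, ans, p_txt, a_txt)

    batch = GRPOBatch([[1, 2]], [[3]], ["What is 2+2?"], ["4"])
    print(next(batch.examples()).answer_text)  # prints: 4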