Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-25 18:11:17 +08:00)

Commit 5aeefc8c47 (parent e80bf95182): update new iterate batches function + nits
@@ -4,6 +4,7 @@ import types
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple
 
+from .utils import GRPOExample
 from transformers import PreTrainedTokenizer
 
 
@@ -11,7 +12,7 @@ class GRPODataset:
     """
     Dataset wrapper for GRPO training data.
     Each example should have a 'prompt' and 'answer' field.
-    Returns data in (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple format.
+    Returns data as GRPOExample instances.
     """
     def __init__(
         self,
@@ -22,33 +23,40 @@ class GRPODataset:
         use_chat_template: bool = False,
         use_prompt: bool = False
     ):
-        self._data = []
+        self._data: List[GRPOExample] = []
         for item in data:
             prompt_str = str(item[prompt_key])
             answer_str = str(item[answer_key])
 
             if use_chat_template:
                 prompt_tokens = tokenizer.apply_chat_template(
                     [
                         {'role': 'system', 'content': """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
 The assistantfirst thinks about the reasoning process in the mind and then provides the user with the answer.
 The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>."""},
                         {'role': 'user', 'content': prompt_str}
                     ],
                 )
                 answer_tokens = tokenizer.encode(answer_str)
             else:
                 if use_prompt:
                     prompt_tokens = tokenizer.encode(f"""A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
 The assistantfirst thinks about the reasoning process in the mind and then provides the user with the answer.
 The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>.
 User: {prompt_str} Assistant: """)
                 else:
                     prompt_tokens = tokenizer.encode(prompt_str)
                 answer_tokens = tokenizer.encode(answer_str)
-            self._data.append((prompt_tokens, answer_tokens, prompt_str, answer_str))
+            self._data.append(GRPOExample(
+                prompt_tokens=prompt_tokens,
+                answer_tokens=answer_tokens,
+                prompt_text=prompt_str,
+                answer_text=answer_str
+            ))
 
-    def __getitem__(self, idx: int) -> Tuple[List[int], List[int], str, str]:
-        """Returns a (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple."""
+    def __getitem__(self, idx: int) -> GRPOExample:
+        """Returns a GRPOExample instance."""
         return self._data[idx]
 
     def __len__(self) -> int:
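The hunks above switch GRPODataset from returning a positional 4-tuple to returning a GRPOExample dataclass (defined in .utils; see the hunks at the end of this commit). A hedged sketch of what a call site looks like after the change, assuming `train_dataset` is a GRPODataset built elsewhere:

```python
# __getitem__ now returns a GRPOExample, so call sites read named attributes
# instead of unpacking by index.
example = train_dataset[0]
prompt_ids = example.prompt_tokens   # List[int], already tokenized
answer_ids = example.answer_tokens   # List[int]
print(example.prompt_text, "->", example.answer_text)

# Before this commit the same item was a plain tuple:
# prompt_ids, answer_ids, prompt_str, answer_str = train_dataset[0]
```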
@@ -2,6 +2,7 @@
 
 import time
 from dataclasses import dataclass, field
+from typing import List, Iterator, Optional
 from pathlib import Path
 import re
 
@@ -10,6 +11,7 @@ import mlx.nn as nn
 import numpy as np
 from mlx.utils import tree_flatten
 
+from .utils import GRPOBatch, GRPOExample
 from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients, iterate_batches
 
 @dataclass
@@ -109,7 +111,7 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list:
     return scores
 
 
-def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
+def generate_grpo(model: nn.Module, prompt, max_tokens, tokenizer, temperature):
     if len(prompt.shape) == 1:
         prompt = prompt[None, :]
     if prompt.shape[1] == 0:
@@ -117,9 +119,11 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
 
     end_sequence = tokenizer.encode("</answer>")
     end_sequence_length = len(end_sequence)
-    output = mx.zeros((prompt.shape[1] + max_tokens,), dtype=mx.int32)
-    output[:prompt.shape[1]] = prompt[0]
-    current_length = prompt.shape[1]
+    initial_length = prompt.shape[1]
+    output = mx.zeros((initial_length + max_tokens,), dtype=mx.int32)
+    output[:initial_length] = prompt[0]
+    current_length = initial_length
 
     try:
         def sample(logits):
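The hunk above only reorders the output-buffer setup in generate_grpo; the decode loop that fills the buffer is not shown in this diff. A rough sketch of how such a loop typically uses these variables (an assumption, not the committed code; `next_token` is a hypothetical sampling helper):

```python
# Hedged sketch: write sampled tokens into the preallocated buffer and stop
# once the tail of the output matches the token ids of "</answer>".
while current_length < initial_length + max_tokens:
    token = next_token(model, output[:current_length])  # hypothetical helper
    output[current_length] = token
    current_length += 1
    if current_length >= end_sequence_length:
        last_tokens = output[current_length - end_sequence_length:current_length].tolist()
        if last_tokens == end_sequence:
            break
```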
@@ -145,7 +149,7 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
             if last_tokens == end_sequence:
                 break
 
-        if current_length > prompt.shape[1]:
+        if current_length > initial_length:
             return output[:current_length]
 
     except Exception as e:
@@ -155,7 +159,7 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
         return None
 
 
-def get_per_token_logps(model, inputs, lengths):
+def get_per_token_logps(model: nn.Module, inputs, lengths):
     logits = model(inputs).astype(mx.float16)
     logits = logits[:, :-1, :]
     targets = inputs[:, 1:]
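get_per_token_logps is only partially visible here; the shown context sets up the usual one-token shift between logits and targets. A hedged sketch of how the per-token log-probabilities are typically gathered from that point (an assumption, not the commit's exact body):

```python
# logits: (batch, seq_len - 1, vocab); targets: (batch, seq_len - 1)
# Log-softmax over the vocabulary, then pick the log-prob of each target token.
log_probs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
token_logps = mx.squeeze(
    mx.take_along_axis(log_probs, mx.expand_dims(targets, -1), axis=-1),
    axis=-1,
)  # token_logps[b, t] = log p(targets[b, t] | inputs[b, :t+1])
```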
@@ -176,10 +180,10 @@ def get_per_token_logps(model, inputs, lengths):
 
 
 def grpo_loss(
-    model,
-    ref_model,
+    model: nn.Module,
+    ref_model: Optional[nn.Module],
     tokenizer,
-    batch,
+    batch=GRPOBatch,
     reward_funcs=None,
     beta=0.1,
     group_size=4,
@@ -187,15 +191,18 @@ def grpo_loss(
     max_tokens=64,
     temperature=1.0
 ):
-    prompt_tokens, answer_tokens, prompt_text, answer_text = batch
-    batch_size = len(prompt_tokens)
+    prompts_tokens = batch.prompt_tokens
+    answers_tokens = batch.answer_tokens
+    prompts_text = batch.prompt_texts
+    answers_text = batch.answer_texts
+    batch_size = len(prompts_tokens)
 
     # Generation logic remains the same
     all_completions = []
     all_completion_texts = []
 
     for i in range(0, batch_size, batch_size):
-        batch_prompts = prompt_tokens[i:i+batch_size]
+        batch_prompts = prompts_tokens[i:i+batch_size]
         for prompt in batch_prompts:
             prompt_tensor = mx.array(prompt)
             for _ in range(group_size):
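With these hunks, grpo_loss reads its inputs from GRPOBatch attributes rather than unpacking a positional tuple. A hypothetical call site, assuming `batch` was yielded by iterate_grpo_batches and the other objects are set up as elsewhere in this module:

```python
# Hedged usage sketch; the argument values below are illustrative, not defaults.
loss, total_tokens, metrics = grpo_loss(
    model=model,
    ref_model=ref_model,
    tokenizer=tokenizer,
    batch=batch,                 # a GRPOBatch, no longer a 4-tuple
    reward_funcs=[r1_accuracy_reward_func, r1_int_reward_func, r1_count_xml],
    beta=0.1,
    group_size=4,
    max_tokens=64,
    temperature=1.0,
)
```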
@@ -219,8 +226,8 @@ def grpo_loss(
     expanded_answers = []
     expanded_prompts = []
     for i in range(batch_size):
-        expanded_answers.extend([answer_text[i]] * group_size)
-        expanded_prompts.extend([prompt_text[i]] * group_size)
+        expanded_answers.extend([answers_text[i]] * group_size)
+        expanded_prompts.extend([prompts_text[i]] * group_size)
 
     max_length = max(ids.shape[0] for ids in all_completions)
     padded_completions = []
@@ -332,60 +339,66 @@ def grpo_loss(
     return loss, sequence_lengths.sum(), metrics
 
 
-def iterate_grpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
-    if not dataset or not isinstance(dataset[0], tuple) or len(dataset[0]) != 4:
-        raise ValueError("Dataset must be list of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples")
-
-    # Sort by length but use generator to avoid keeping full sorted list in memory
-    def length_key(i):
-        return len(dataset[i][0]) + len(dataset[i][1])
-
-    idx = sorted(range(len(dataset)), key=length_key)
+def iterate_grpo_batches(
+    dataset: List[GRPOExample],
+    batch_size: int,
+    max_seq_length: int,
+    train: bool = False,
+) -> Iterator[GRPOBatch]:
 
     if len(dataset) < batch_size:
         raise ValueError(
             f"Dataset must have at least batch_size={batch_size} "
             f"examples but only has {len(dataset)}."
         )
 
+    # Get MLX distributed setup
     step = mx.distributed.init().size()
     if batch_size % step != 0:
         raise ValueError("The batch size must be divisible by the number of workers")
 
-    # Use generator for batch indices
-    def batch_index_generator():
-        for i in range(0, len(idx) - batch_size + 1, batch_size):
-            yield idx[i : i + batch_size : step]
+    # Sort by combined length for efficient batching
+    def length_key(example: GRPOExample) -> int:
+        return len(example.prompt_tokens) + len(example.answer_tokens)
+
+    sorted_dataset = sorted(dataset, key=length_key)
+
+    # Create batch indices
+    num_complete_batches = (len(dataset) - batch_size + 1) // batch_size
+    batch_starts = range(0, num_complete_batches * batch_size, batch_size)
 
     while True:
-        indices = (
-            np.random.permutation(list(batch_index_generator())) if train
-            else batch_index_generator()
-        )
+        # Shuffle batch start indices
+        shuffled_starts = np.random.permutation(batch_starts)
 
-        for batch_idx in indices:
-            current_batch = [dataset[j] for j in batch_idx]
-
-            prompts_tokens = [item[0] for item in current_batch]
-            answers_tokens = [item[1] for item in current_batch]
-            prompts_text = [item[2] for item in current_batch]
-            answers_text = [item[3] for item in current_batch]
-
-            if any(len(p) > max_seq_length for p in prompts_tokens):
+        for start_idx in shuffled_starts:
+            # Account for distributed setup by taking every step-th example
+            batch_idx = list(range(start_idx, start_idx + batch_size, step))
+            current_batch = [sorted_dataset[j] for j in batch_idx]
+
+            # Create batch using dataclass attributes
+            batch = GRPOBatch(
+                prompt_tokens=[ex.prompt_tokens for ex in current_batch],
+                answer_tokens=[ex.answer_tokens for ex in current_batch],
+                prompt_texts=[ex.prompt_text for ex in current_batch],
+                answer_texts=[ex.answer_text for ex in current_batch]
+            )
+
+            # Check sequence lengths
+            if any(len(tokens) > max_seq_length for tokens in batch.prompt_tokens):
                 print(
                     f"[WARNING] Some prompts are longer than {max_seq_length} tokens. "
                     "Long prompts will be truncated."
                 )
 
-            yield prompts_tokens, answers_tokens, prompts_text, answers_text
+            yield batch
 
         if not train:
             break
 
 
 def evaluate_grpo(
-    model,
-    ref_model,
+    model: nn.Module,
+    ref_model: Optional[nn.Module],
     dataset,
     tokenizer,
     batch_size,
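The rewritten iterate_grpo_batches sorts examples by combined prompt-plus-answer length, walks fixed batch windows over the sorted data, thins each window by the distributed world size, and yields GRPOBatch objects instead of four parallel lists. A hedged usage sketch (dataset construction and sizes are illustrative):

```python
# Each yielded GRPOBatch holds parallel lists of length batch_size // world_size.
for batch in iterate_grpo_batches(
    dataset=train_dataset,   # List[GRPOExample], e.g. the examples held by a GRPODataset
    batch_size=4,
    max_seq_length=512,
    train=True,              # reshuffle and loop forever; False stops after one pass
):
    assert len(batch.prompt_tokens) == len(batch.answer_texts)
    # ... run grpo_loss on the batch, then continue/break as the trainer dictates
    break
```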
@@ -415,7 +428,6 @@ def evaluate_grpo(
         index_iterator,
         iterate_batches(
             dataset=dataset,
-            tokenizer=tokenizer,
             batch_size=batch_size,
             max_seq_length=max_seq_length,
         ),
@@ -459,12 +471,12 @@ def evaluate_grpo(
 
 
 def train_grpo(
-    model,
-    ref_model,
+    model: nn.Module,
+    ref_model: Optional[nn.Module],
     tokenizer,
     optimizer,
-    train_dataset,
-    val_dataset,
+    train_dataset: List[GRPOExample],
+    val_dataset: List[GRPOExample],
     reward_funcs = [
         r1_accuracy_reward_func,
         r1_int_reward_func,
@@ -535,7 +547,6 @@ def train_grpo(
         range(1, args.iters + 1),
         iterate_batches(
             dataset=train_dataset,
-            tokenizer=tokenizer,
             batch_size=args.batch_size,
             max_seq_length=args.max_seq_length,
             train=True,
@@ -2,7 +2,8 @@
 import json
 import types
 from pathlib import Path
-from typing import Dict
+from typing import Dict, List
+from dataclasses import dataclass
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -275,3 +276,19 @@ def print_trainable_parameters(model):
         f"Trainable parameters: {(trainable_p * 100 / total_p):.3f}% "
         f"({trainable_p:.3f}M/{total_p:.3f}M)"
     )
+
+@dataclass
+class GRPOExample:
+    """Single example for GRPO training/inference."""
+    prompt_tokens: List[int]
+    answer_tokens: List[int]
+    prompt_text: str
+    answer_text: str
+
+
+@dataclass
+class GRPOBatch:
+    """A batch of GRPO examples."""
+    prompt_tokens: List[List[int]]
+    answer_tokens: List[List[int]]
+    prompt_texts: List[str]
+    answer_texts: List[str]
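These two dataclasses are the new currency between the dataset, the batch iterator, and the loss. A small hedged sketch of how they compose (field values are illustrative, not from the repo):

```python
# Construct one example and a single-element batch by hand, mirroring what
# GRPODataset and iterate_grpo_batches now produce.
ex = GRPOExample(
    prompt_tokens=[1, 2, 3],
    answer_tokens=[4, 5],
    prompt_text="What is 2 + 2?",
    answer_text="<think>2 + 2 = 4</think><answer>4</answer>",
)

batch = GRPOBatch(
    prompt_tokens=[ex.prompt_tokens],
    answer_tokens=[ex.answer_tokens],
    prompt_texts=[ex.prompt_text],
    answer_texts=[ex.answer_text],
)
assert len(batch.prompt_tokens) == len(batch.answer_texts) == 1
```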