mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-07-04 15:56:16 +08:00)

updates

This commit is contained in:
parent c42e858d7e
commit e33d9d509b
@@ -2,9 +2,8 @@ import itertools
import json
import types
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Tuple

from .utils import GRPOExample
from transformers import PreTrainedTokenizer

@@ -12,7 +11,7 @@ class GRPODataset:
    """
    Dataset wrapper for GRPO training data.
    Each example should have a 'prompt' and 'answer' field.
    Returns data as GRPOExample instances.
    Returns data in (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple format.
    """
    def __init__(
        self,
@@ -23,40 +22,33 @@ class GRPODataset:
        use_chat_template: bool = False,
        use_prompt: bool = False
    ):
        self._data: List[GRPOExample] = []
        self._data = []
        for item in data:
            prompt_str = str(item[prompt_key])
            answer_str = str(item[answer_key])

            if use_chat_template:
                prompt_tokens = tokenizer.apply_chat_template(
                    [
                        {'role': 'system', 'content': """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>."""},
                        {'role': 'user', 'content': prompt_str}
                    ],
                )
                answer_tokens = tokenizer.encode(answer_str)
            else:
                if use_prompt:
                    prompt_tokens = tokenizer.encode(f"""A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>.
User: {prompt_str} Assistant: """)
                else:
                    prompt_tokens = tokenizer.encode(prompt_str)
                answer_tokens = tokenizer.encode(answer_str)
            self._data.append((prompt_tokens, answer_tokens, prompt_str, answer_str))

            self._data.append(GRPOExample(
                prompt_tokens=prompt_tokens,
                answer_tokens=answer_tokens,
                prompt_text=prompt_str,
                answer_text=answer_str
            ))

    def __getitem__(self, idx: int) -> GRPOExample:
        """Returns a GRPOExample instance."""
    def __getitem__(self, idx: int) -> Tuple[List[int], List[int], str, str]:
        """Returns a (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple."""
        return self._data[idx]

    def __len__(self) -> int:
@@ -318,7 +310,7 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer):
        train, valid, test = load_local_dataset(args, data_path, tokenizer, args)
    else:
        print(f"Loading Hugging Face dataset {args.data}.")
        train, valid, test = load_hf_dataset(args, args.data, tokenizer, args)
        train, valid, test = load_hf_dataset(args.data, tokenizer, args)

    if args.train and len(train) == 0:
        raise ValueError(
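For orientation, a minimal usage sketch of the GRPODataset wrapper diffed above. It is not part of the commit; the model name, file path, and keyword-argument names are assumptions drawn from the fields visible in the diff.

# Hypothetical usage sketch of GRPODataset (not part of this commit).
import json
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # model choice is illustrative
with open("train.jsonl") as f:                                           # path is illustrative
    records = [json.loads(line) for line in f]                           # each record: {"prompt": ..., "answer": ...}

dataset = GRPODataset(
    data=records,
    tokenizer=tokenizer,
    prompt_key="prompt",
    answer_key="answer",
    use_chat_template=True,
)
print(len(dataset), dataset[0])  # a tuple or a GRPOExample, depending on which side of the diff is running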
@@ -1,18 +1,18 @@
# Copyright © 2024 Apple Inc.

import time
from typing import List, Optional, Callable
from dataclasses import dataclass, field
from typing import List, Iterator, Optional
from pathlib import Path
import time
import re

from mlx.utils import tree_flatten
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.utils import tree_flatten

from .utils import GRPOBatch, GRPOExample
from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients, iterate_batches
from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients


@dataclass
class GRPOTrainingArgs(TrainingArgs):
@@ -37,6 +37,9 @@ class GRPOTrainingArgs(TrainingArgs):
    )


RewardFunctions = Callable[[List[str], List[str], List[str]], List[float]]


def r1_extract_xml_answer(text: str) -> str:
    """Extracts the answer from an XML formatted text string."""
    try:
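The RewardFunctions alias pins down the reward-callable signature. As an illustration only (not from the commit; the argument roles are assumed to be prompts, completions, and reference answers), a custom reward matching that signature could look like this:

# Hedged sketch of a reward callable matching RewardFunctions:
# Callable[[List[str], List[str], List[str]], List[float]].
from typing import List

def exact_match_reward(
    prompts: List[str],
    completions: List[str],
    answers: List[str],
) -> List[float]:
    # 1.0 when the extracted <answer> block matches the reference answer, else 0.0.
    rewards = []
    for completion, answer in zip(completions, answers):
        extracted = r1_extract_xml_answer(completion)
        rewards.append(1.0 if extracted.strip() == answer.strip() else 0.0)
    return rewards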
@@ -180,10 +183,10 @@ def get_per_token_logps(model: nn.Module, inputs, lengths):


def grpo_loss(
    model: nn.Module,
    ref_model: Optional[nn.Module],
    model,
    ref_model,
    tokenizer,
    batch=GRPOBatch,
    batch,
    reward_funcs=None,
    beta=0.1,
    group_size=4,
@@ -191,18 +194,14 @@ def grpo_loss(
    max_tokens=64,
    temperature=1.0
):
    prompts_tokens = batch.prompt_tokens
    answers_tokens = batch.answer_tokens
    prompts_text = batch.prompt_texts
    answers_text = batch.answer_texts
    batch_size = len(prompts_tokens)
    prompt_tokens, answer_tokens, prompt_text, answer_text = batch
    batch_size = len(prompt_tokens)

    # Generation logic remains the same
    all_completions = []
    all_completion_texts = []

    for i in range(0, batch_size, batch_size):
        batch_prompts = prompts_tokens[i:i+batch_size]
        batch_prompts = prompt_tokens[i:i+batch_size]
        for prompt in batch_prompts:
            prompt_tensor = mx.array(prompt)
            for _ in range(group_size):
@@ -212,8 +211,6 @@ def grpo_loss(
                    completion_text = tokenizer.decode(completion_ids.tolist())
                    all_completions.append(completion_ids)
                    all_completion_texts.append(completion_text)

                    # Clear completion tensors
                    mx.eval(completion_ids)
                    del completion_ids
                except Exception as e:
@@ -222,12 +219,11 @@ def grpo_loss(

    mx.metal.clear_cache()

    # Prepare inputs
    expanded_answers = []
    expanded_prompts = []
    for i in range(batch_size):
        expanded_answers.extend([answers_text[i]] * group_size)
        expanded_prompts.extend([prompts_text[i]] * group_size)
        expanded_answers.extend([answer_text[i]] * group_size)
        expanded_prompts.extend([prompt_text[i]] * group_size)

    max_length = max(ids.shape[0] for ids in all_completions)
    padded_completions = []
@@ -260,6 +256,8 @@ def grpo_loss(
        ref_token_log_probs = token_log_probs
    else:
        ref_token_log_probs = get_per_token_logps(ref_model, inputs, lengths)
        mx.eval(ref_token_log_probs)
        mx.metal.clear_cache()

    max_len = max(x.shape[0] for x in token_log_probs)
    padded_log_probs = []
@@ -275,7 +273,7 @@ def grpo_loss(
    token_log_probs = mx.stack(padded_log_probs)
    ref_token_log_probs = mx.stack(padded_ref_log_probs)

    # Calculate rewards and advantages
    # Rewards and advantages
    rewards = mx.zeros((len(all_completions),))
    for reward_func in reward_funcs:
        func_rewards = mx.array(reward_func(
@@ -288,7 +286,7 @@ def grpo_loss(
    if len(reward_funcs) > 1:
        rewards /= len(reward_funcs)

    # Reshape rewards and compute advantages following GRPO formula
    # Reshape rewards and compute advantages
    rewards_reshaped = rewards.reshape(batch_size, group_size)
    mean_rewards = mx.broadcast_to(mx.mean(rewards_reshaped, axis=1)[:, None], (rewards_reshaped.shape[0], group_size)).reshape(-1)
    std_rewards = mx.broadcast_to(mx.std(rewards_reshaped, axis=1)[:, None], (rewards_reshaped.shape[0], group_size)).reshape(-1)
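The mean/std broadcasts above implement the group-relative advantage. A compact sketch of the same computation, with an assumed epsilon for numerical stability (the exact value used in the hidden lines of the diff is not shown):

# Sketch of the group-relative advantage: A_ij = (r_ij - mean_i) / (std_i + epsilon),
# computed per prompt group of size group_size.
import mlx.core as mx

def group_advantages(rewards: mx.array, batch_size: int, group_size: int, epsilon: float = 1e-4) -> mx.array:
    grouped = rewards.reshape(batch_size, group_size)
    mean = mx.mean(grouped, axis=1, keepdims=True)
    std = mx.std(grouped, axis=1, keepdims=True)
    return ((grouped - mean) / (std + epsilon)).reshape(-1)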
@@ -303,7 +301,7 @@ def grpo_loss(
    # Compute policy ratio
    policy_ratio = mx.exp(mx.array(token_log_probs - mx.stop_gradient(ref_token_log_probs)))

    # Compute per-token loss following GRPO formula
    # Compute per-token loss
    per_token_loss = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask)

    # Average over tokens
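A sketch of the masked per-token objective written out as a helper. Here kl_div and length_mask are assumed to be precomputed arrays with the shapes used above, and the final reduction is one plausible choice rather than the commit's exact averaging:

# Sketch mirroring the per_token_loss line above; shapes:
# policy_ratio, kl_div, length_mask: (num_sequences, max_len); advantages: (num_sequences,).
import mlx.core as mx

def masked_grpo_loss(policy_ratio, advantages, kl_div, length_mask, beta=0.1):
    per_token = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask)
    # Average over valid tokens only (padding is zeroed out by length_mask).
    return per_token.sum() / length_mask.sum()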
@@ -339,58 +337,50 @@ def grpo_loss(
    return loss, sequence_lengths.sum(), metrics


def iterate_grpo_batches(
    dataset: List[GRPOExample],
    batch_size: int,
    max_seq_length: int,
    train: bool = False,
) -> Iterator[GRPOBatch]:
def iterate_grpo_batches(dataset, batch_size, max_seq_length, train=False):
    if not dataset or not isinstance(dataset[0], tuple) or len(dataset[0]) != 4:
        raise ValueError("Dataset must be list of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples")

    def length_key(i):
        return len(dataset[i][0]) + len(dataset[i][1])

    idx = sorted(range(len(dataset)), key=length_key)

    if len(dataset) < batch_size:
        raise ValueError(
            f"Dataset must have at least batch_size={batch_size} "
            f"examples but only has {len(dataset)}."
        )

    # Get MLX distributed setup
    step = mx.distributed.init().size()
    if batch_size % step != 0:
        raise ValueError("The batch size must be divisible by the number of workers")

    # Sort by combined length for efficient batching
    def length_key(example: GRPOExample) -> int:
        return len(example.prompt_tokens) + len(example.answer_tokens)

    sorted_dataset = sorted(dataset, key=length_key)

    # Create batch indices
    num_complete_batches = (len(dataset) - batch_size + 1) // batch_size
    batch_starts = range(0, num_complete_batches * batch_size, batch_size)
    def batch_index_generator():
        for i in range(0, len(idx) - batch_size + 1, batch_size):
            yield idx[i : i + batch_size : step]

    while True:
        # Shuffle batch start indices
        shuffled_starts = np.random.permutation(batch_starts)
        indices = (
            np.random.permutation(list(batch_index_generator())) if train
            else batch_index_generator()
        )

        for start_idx in shuffled_starts:
            # Account for distributed setup by taking every step-th example
            batch_idx = list(range(start_idx, start_idx + batch_size, step))
            current_batch = [sorted_dataset[j] for j in batch_idx]
        for batch_idx in indices:
            current_batch = [dataset[j] for j in batch_idx]

            # Create batch using dataclass attributes
            batch = GRPOBatch(
                prompt_tokens=[ex.prompt_tokens for ex in current_batch],
                answer_tokens=[ex.answer_tokens for ex in current_batch],
                prompt_texts=[ex.prompt_text for ex in current_batch],
                answer_texts=[ex.answer_text for ex in current_batch]
            )
            prompts_tokens = [item[0] for item in current_batch]
            answers_tokens = [item[1] for item in current_batch]
            prompts_text = [item[2] for item in current_batch]
            answers_text = [item[3] for item in current_batch]

            # Check sequence lengths
            if any(len(tokens) > max_seq_length for tokens in batch.prompt_tokens):
            if any(len(p) > max_seq_length for p in prompts_tokens):
                print(
                    f"[WARNING] Some prompts are longer than {max_seq_length} tokens. "
                    "Long prompts will be truncated."
                )

            yield batch
            yield prompts_tokens, answers_tokens, prompts_text, answers_text

        if not train:
            break
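The strided slice idx[i : i + batch_size : step] is what thins each length-sorted batch for the distributed setup; a small worked example with illustrative numbers:

# Worked example of the stride used above (values chosen for illustration):
batch_size = 8
step = 2                                 # mx.distributed.init().size() on a 2-process run
idx = list(range(16))                    # dataset indices sorted by prompt+answer length
this_batch = idx[0:batch_size:step]
print(this_batch)                        # [0, 2, 4, 6]: every step-th index, batch_size // step examples per process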
@@ -407,7 +397,7 @@ def evaluate_grpo(
    epsilon: float,
    group_size: int,
    max_seq_length,
    reward_funcs = None,
    reward_funcs: Optional[List[RewardFunctions]] = None,
    loss_fn: callable = grpo_loss,
    iterate_batches: callable = iterate_grpo_batches
):
@@ -418,12 +408,10 @@ def evaluate_grpo(
    """
    all_losses = 0
    ntokens = 0
    all_metrics = None  # Initialize metrics dictionary
    all_metrics = None

    # Create iterator for batches
    index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)

    # Iterate through batches
    for _, batch in zip(
        index_iterator,
        iterate_batches(
@@ -432,7 +420,6 @@ def evaluate_grpo(
            max_seq_length=max_seq_length,
        ),
    ):
        # Calculate loss for current batch
        losses, toks, metrics = loss_fn(
            model=model,
            tokenizer=tokenizer,
@@ -444,18 +431,15 @@ def evaluate_grpo(
            ref_model=ref_model
        )

        # Accumulate losses and tokens
        all_losses += losses * toks
        ntokens += toks

        # Accumulate metrics
        if all_metrics is None:
            all_metrics = {k: v * toks for k, v in metrics.items()}
        else:
            for k, v in metrics.items():
                all_metrics[k] += v * toks

    # Evaluate accumulated values
    mx.eval(all_losses, ntokens)

    # Aggregate across distributed workers
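A sketch of the final step implied by this accumulation: every quantity was summed as value * toks, so the averages come from dividing by ntokens. The cross-worker aggregation noted in the last comment above is omitted here, since those lines are not visible in the diff:

# Hypothetical helper showing the token-weighted averaging implied above.
def finalize_eval(all_losses, ntokens, all_metrics):
    avg_loss = (all_losses / ntokens).item()
    avg_metrics = {k: v / ntokens for k, v in all_metrics.items()}
    return avg_loss, avg_metrics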
@@ -475,9 +459,9 @@ def train_grpo(
    ref_model: Optional[nn.Module],
    tokenizer,
    optimizer,
    train_dataset: List[GRPOExample],
    val_dataset: List[GRPOExample],
    reward_funcs = [
    train_dataset,
    val_dataset,
    reward_funcs: Optional[List[RewardFunctions]] = [
        r1_accuracy_reward_func,
        r1_int_reward_func,
        r1_strict_format_reward_func,

@@ -2,8 +2,7 @@
import json
import types
from pathlib import Path
from typing import Dict, List
from dataclasses import dataclass
from typing import Dict

import mlx.core as mx
import mlx.nn as nn
@@ -276,19 +275,3 @@ def print_trainable_parameters(model):
        f"Trainable parameters: {(trainable_p * 100 / total_p):.3f}% "
        f"({trainable_p:.3f}M/{total_p:.3f}M)"
    )

@dataclass
class GRPOExample:
    """Single example for GRPO training/inference."""
    prompt_tokens: List[int]
    answer_tokens: List[int]
    prompt_text: str
    answer_text: str

@dataclass
class GRPOBatch:
    """A batch of GRPO examples."""
    prompt_tokens: List[List[int]]
    answer_tokens: List[List[int]]
    prompt_texts: List[str]
    answer_texts: List[str]
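For reference, an illustrative construction of these two containers, mirroring the batch assembly shown earlier in the grpo_trainer.py diff (the token ids and texts are made up):

# Illustrative only: building a GRPOBatch from GRPOExample instances.
examples = [
    GRPOExample(prompt_tokens=[1, 2, 3], answer_tokens=[4, 5], prompt_text="2+2?", answer_text="4"),
    GRPOExample(prompt_tokens=[1, 2], answer_tokens=[6], prompt_text="1+1?", answer_text="2"),
]
batch = GRPOBatch(
    prompt_tokens=[ex.prompt_tokens for ex in examples],
    answer_tokens=[ex.answer_tokens for ex in examples],
    prompt_texts=[ex.prompt_text for ex in examples],
    answer_texts=[ex.answer_text for ex in examples],
)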