Mirror of https://github.com/ml-explore/mlx-examples.git
(synced 2025-07-06 00:31:13 +08:00)

Commit e33d9d509b ("updates")
Parent c42e858d7e
@@ -2,9 +2,8 @@ import itertools
 import json
 import types
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Tuple

-from .utils import GRPOExample
 from transformers import PreTrainedTokenizer

@@ -12,7 +11,7 @@ class GRPODataset:
     """
     Dataset wrapper for GRPO training data.
    Each example should have a 'prompt' and 'answer' field.
-    Returns data as GRPOExample instances.
+    Returns data in (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple format.
     """
     def __init__(
         self,
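Note (illustrative sketch, not part of the commit): each raw record only needs the two string fields named by prompt_key and answer_key; argument names below are inferred from the hunk bodies, and the model id is hypothetical.

    from transformers import AutoTokenizer

    # Hypothetical records; any iterable of dicts with these keys works.
    data = [
        {"prompt": "What is 2 + 2?", "answer": "4"},
        {"prompt": "Name a prime number.", "answer": "7"},
    ]

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # hypothetical choice
    dataset = GRPODataset(data=data, tokenizer=tokenizer,
                          prompt_key="prompt", answer_key="answer")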
@@ -23,11 +22,10 @@ class GRPODataset:
         use_chat_template: bool = False,
         use_prompt: bool = False
     ):
-        self._data: List[GRPOExample] = []
+        self._data = []
         for item in data:
             prompt_str = str(item[prompt_key])
             answer_str = str(item[answer_key])

             if use_chat_template:
                 prompt_tokens = tokenizer.apply_chat_template(
                     [
@@ -47,16 +45,10 @@ class GRPODataset:
         else:
             prompt_tokens = tokenizer.encode(prompt_str)
             answer_tokens = tokenizer.encode(answer_str)
+            self._data.append((prompt_tokens, answer_tokens, prompt_str, answer_str))

-            self._data.append(GRPOExample(
-                prompt_tokens=prompt_tokens,
-                answer_tokens=answer_tokens,
-                prompt_text=prompt_str,
-                answer_text=answer_str
-            ))
-
-    def __getitem__(self, idx: int) -> GRPOExample:
-        """Returns a GRPOExample instance."""
+    def __getitem__(self, idx: int) -> Tuple[List[int], List[int], str, str]:
+        """Returns a (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple."""
         return self._data[idx]

     def __len__(self) -> int:
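Note (illustrative sketch, not part of the commit): with GRPOExample gone, call sites unpack the 4-tuple positionally, per the new docstring.

    prompt_tokens, answer_tokens, prompt_str, answer_str = dataset[0]
    print(len(prompt_tokens), answer_str)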
@@ -318,7 +310,7 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer):
         train, valid, test = load_local_dataset(args, data_path, tokenizer, args)
     else:
         print(f"Loading Hugging Face dataset {args.data}.")
-        train, valid, test = load_hf_dataset(args, args.data, tokenizer, args)
+        train, valid, test = load_hf_dataset(args.data, tokenizer, args)

     if args.train and len(train) == 0:
         raise ValueError(
@@ -1,18 +1,18 @@
 # Copyright © 2024 Apple Inc.

-import time
+from typing import List, Optional, Callable
 from dataclasses import dataclass, field
-from typing import List, Iterator, Optional
 from pathlib import Path
+import time
 import re

+from mlx.utils import tree_flatten
 import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
-from mlx.utils import tree_flatten

-from .utils import GRPOBatch, GRPOExample
-from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients, iterate_batches
+from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients

 @dataclass
 class GRPOTrainingArgs(TrainingArgs):
@@ -37,6 +37,9 @@ class GRPOTrainingArgs(TrainingArgs):
     )

+
+RewardFunctions = Callable[[List[str], List[str], List[str]], List[float]]
+

 def r1_extract_xml_answer(text: str) -> str:
     """Extracts the answer from an XML formatted text string."""
     try:
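Note (illustrative sketch, not part of the commit): the new alias fixes the contract shared by every reward function: three parallel lists of strings in, one float per completion out. The argument order below (prompts, completions, answers) is an assumption; the alias alone does not pin it down.

    from typing import List

    def exact_match_reward(prompts: List[str], completions: List[str],
                           answers: List[str]) -> List[float]:
        # 1.0 when the completion contains the reference answer verbatim.
        return [1.0 if ans in comp else 0.0
                for comp, ans in zip(completions, answers)]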
@@ -180,10 +183,10 @@ def get_per_token_logps(model: nn.Module, inputs, lengths):


 def grpo_loss(
-    model: nn.Module,
-    ref_model: Optional[nn.Module],
+    model,
+    ref_model,
     tokenizer,
-    batch=GRPOBatch,
+    batch,
     reward_funcs=None,
     beta=0.1,
     group_size=4,
@@ -191,18 +194,14 @@ def grpo_loss(
     max_tokens=64,
     temperature=1.0
 ):
-    prompts_tokens = batch.prompt_tokens
-    answers_tokens = batch.answer_tokens
-    prompts_text = batch.prompt_texts
-    answers_text = batch.answer_texts
-    batch_size = len(prompts_tokens)
+    prompt_tokens, answer_tokens, prompt_text, answer_text = batch
+    batch_size = len(prompt_tokens)

-    # Generation logic remains the same
     all_completions = []
     all_completion_texts = []

     for i in range(0, batch_size, batch_size):
-        batch_prompts = prompts_tokens[i:i+batch_size]
+        batch_prompts = prompt_tokens[i:i+batch_size]
         for prompt in batch_prompts:
             prompt_tensor = mx.array(prompt)
             for _ in range(group_size):
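Note (illustrative sketch, not part of the commit): the outer loop's step equals its stop, so it executes exactly once; every prompt is then sampled group_size times, giving batch_size * group_size completions grouped by prompt. The later rewards.reshape(batch_size, group_size) relies on this layout.

    batch_size, group_size = 2, 4
    # Completion j for prompt i lives at index i * group_size + j.
    for i in range(batch_size):
        group = list(range(i * group_size, (i + 1) * group_size))
        print(f"prompt {i} -> completion indices {group}")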
@@ -212,8 +211,6 @@ def grpo_loss(
                    completion_text = tokenizer.decode(completion_ids.tolist())
                    all_completions.append(completion_ids)
                    all_completion_texts.append(completion_text)

-                    # Clear completion tensors
                    mx.eval(completion_ids)
                    del completion_ids
            except Exception as e:
@@ -222,12 +219,11 @@ def grpo_loss(

     mx.metal.clear_cache()

-    # Prepare inputs
     expanded_answers = []
     expanded_prompts = []
     for i in range(batch_size):
-        expanded_answers.extend([answers_text[i]] * group_size)
-        expanded_prompts.extend([prompts_text[i]] * group_size)
+        expanded_answers.extend([answer_text[i]] * group_size)
+        expanded_prompts.extend([prompt_text[i]] * group_size)

     max_length = max(ids.shape[0] for ids in all_completions)
     padded_completions = []
@@ -260,6 +256,8 @@ def grpo_loss(
         ref_token_log_probs = token_log_probs
     else:
         ref_token_log_probs = get_per_token_logps(ref_model, inputs, lengths)
+        mx.eval(ref_token_log_probs)
+        mx.metal.clear_cache()

     max_len = max(x.shape[0] for x in token_log_probs)
     padded_log_probs = []
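Note (illustrative sketch, not part of the commit): MLX is lazy, so forcing evaluation before clearing the cache releases the reference model's intermediate buffers instead of keeping the whole graph alive. A minimal statement of the force-then-free pattern:

    import mlx.core as mx

    x = mx.random.normal((1024, 1024))
    y = (x @ x.T).sum()     # builds a lazy graph; nothing computed yet
    mx.eval(y)              # force computation so intermediates can be dropped
    mx.metal.clear_cache()  # release MLX's cached Metal buffers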
@@ -275,7 +273,7 @@ def grpo_loss(
     token_log_probs = mx.stack(padded_log_probs)
     ref_token_log_probs = mx.stack(padded_ref_log_probs)

-    # Calculate rewards and advantages
+    # Rewards and advantages
     rewards = mx.zeros((len(all_completions),))
     for reward_func in reward_funcs:
         func_rewards = mx.array(reward_func(
@@ -288,7 +286,7 @@ def grpo_loss(
     if len(reward_funcs) > 1:
         rewards /= len(reward_funcs)

-    # Reshape rewards and compute advantages following GRPO formula
+    # Reshape rewards and compute advantages
     rewards_reshaped = rewards.reshape(batch_size, group_size)
     mean_rewards = mx.broadcast_to(mx.mean(rewards_reshaped, axis=1)[:, None], (rewards_reshaped.shape[0], group_size)).reshape(-1)
     std_rewards = mx.broadcast_to(mx.std(rewards_reshaped, axis=1)[:, None], (rewards_reshaped.shape[0], group_size)).reshape(-1)
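Note (illustrative sketch, not part of the commit): the broadcasting above implements the GRPO group-relative advantage, standardizing each completion's reward against the other completions of the same prompt. A NumPy re-statement; the epsilon value in the denominator is assumed, not shown in this hunk:

    import numpy as np

    rewards = np.array([[1.0, 0.0, 1.0, 0.0],   # group of 4 completions, prompt 0
                        [0.0, 0.0, 0.0, 1.0]])  # group of 4 completions, prompt 1
    mean = rewards.mean(axis=1, keepdims=True)
    std = rewards.std(axis=1, keepdims=True)
    advantages = ((rewards - mean) / (std + 1e-4)).reshape(-1)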
@@ -303,7 +301,7 @@ def grpo_loss(
     # Compute policy ratio
     policy_ratio = mx.exp(mx.array(token_log_probs - mx.stop_gradient(ref_token_log_probs)))

-    # Compute per-token loss following GRPO formula
+    # Compute per-token loss
     per_token_loss = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask)

     # Average over tokens
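Note (illustrative sketch, not part of the commit): stop_gradient on the reference log-probs means gradients flow only through the current policy's term of the ratio. The "average over tokens" step that follows is not shown in this hunk; a common masked reduction, stated under that assumption:

    import mlx.core as mx

    # Hypothetical shapes: (num_sequences, max_len)
    per_token_loss = mx.zeros((8, 16))
    length_mask = mx.ones((8, 16))
    # Per-sequence mean over real tokens only, then mean over sequences.
    seq_loss = (per_token_loss * length_mask).sum(axis=1) / length_mask.sum(axis=1)
    loss = seq_loss.mean()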
@@ -339,58 +337,50 @@ def grpo_loss(
     return loss, sequence_lengths.sum(), metrics


-def iterate_grpo_batches(
-    dataset: List[GRPOExample],
-    batch_size: int,
-    max_seq_length: int,
-    train: bool = False,
-) -> Iterator[GRPOBatch]:
+def iterate_grpo_batches(dataset, batch_size, max_seq_length, train=False):
+    if not dataset or not isinstance(dataset[0], tuple) or len(dataset[0]) != 4:
+        raise ValueError("Dataset must be list of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples")
+
+    def length_key(i):
+        return len(dataset[i][0]) + len(dataset[i][1])
+
+    idx = sorted(range(len(dataset)), key=length_key)

     if len(dataset) < batch_size:
         raise ValueError(
             f"Dataset must have at least batch_size={batch_size} "
             f"examples but only has {len(dataset)}."
         )

-    # Get MLX distributed setup
     step = mx.distributed.init().size()
     if batch_size % step != 0:
         raise ValueError("The batch size must be divisible by the number of workers")

-    # Sort by combined length for efficient batching
-    def length_key(example: GRPOExample) -> int:
-        return len(example.prompt_tokens) + len(example.answer_tokens)
-
-    sorted_dataset = sorted(dataset, key=length_key)
-
-    # Create batch indices
-    num_complete_batches = (len(dataset) - batch_size + 1) // batch_size
-    batch_starts = range(0, num_complete_batches * batch_size, batch_size)
+    def batch_index_generator():
+        for i in range(0, len(idx) - batch_size + 1, batch_size):
+            yield idx[i : i + batch_size : step]

     while True:
-        # Shuffle batch start indices
-        shuffled_starts = np.random.permutation(batch_starts)
-
-        for start_idx in shuffled_starts:
-            # Account for distributed setup by taking every step-th example
-            batch_idx = list(range(start_idx, start_idx + batch_size, step))
-            current_batch = [sorted_dataset[j] for j in batch_idx]
-
-            # Create batch using dataclass attributes
-            batch = GRPOBatch(
-                prompt_tokens=[ex.prompt_tokens for ex in current_batch],
-                answer_tokens=[ex.answer_tokens for ex in current_batch],
-                prompt_texts=[ex.prompt_text for ex in current_batch],
-                answer_texts=[ex.answer_text for ex in current_batch]
-            )
+        indices = (
+            np.random.permutation(list(batch_index_generator())) if train
+            else batch_index_generator()
+        )

-            # Check sequence lengths
-            if any(len(tokens) > max_seq_length for tokens in batch.prompt_tokens):
+        for batch_idx in indices:
+            current_batch = [dataset[j] for j in batch_idx]
+
+            prompts_tokens = [item[0] for item in current_batch]
+            answers_tokens = [item[1] for item in current_batch]
+            prompts_text = [item[2] for item in current_batch]
+            answers_text = [item[3] for item in current_batch]
+
+            if any(len(p) > max_seq_length for p in prompts_tokens):
                 print(
                     f"[WARNING] Some prompts are longer than {max_seq_length} tokens. "
                     "Long prompts will be truncated."
                 )

-            yield batch
+            yield prompts_tokens, answers_tokens, prompts_text, answers_text

             if not train:
                 break
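Note (illustrative sketch, not part of the commit): the rewritten generator sorts indices by combined prompt+answer length, strides each batch by the distributed world size, and shuffles whole batches only when train=True. A hedged single-worker usage (step == 1, so a batch yields batch_size examples):

    batches = iterate_grpo_batches(dataset, batch_size=4,
                                   max_seq_length=512, train=False)
    prompts_tokens, answers_tokens, prompts_text, answers_text = next(batches)
    assert len(prompts_tokens) == 4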
@@ -407,7 +397,7 @@ def evaluate_grpo(
     epsilon: float,
     group_size: int,
     max_seq_length,
-    reward_funcs = None,
+    reward_funcs: Optional[List[RewardFunctions]] = None,
     loss_fn: callable = grpo_loss,
     iterate_batches: callable = iterate_grpo_batches
 ):
@@ -418,12 +408,10 @@ def evaluate_grpo(
     """
     all_losses = 0
     ntokens = 0
-    all_metrics = None  # Initialize metrics dictionary
+    all_metrics = None

-    # Create iterator for batches
     index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)

-    # Iterate through batches
     for _, batch in zip(
         index_iterator,
         iterate_batches(
@@ -432,7 +420,6 @@ def evaluate_grpo(
             max_seq_length=max_seq_length,
         ),
     ):
-        # Calculate loss for current batch
         losses, toks, metrics = loss_fn(
             model=model,
             tokenizer=tokenizer,
@@ -444,18 +431,15 @@ def evaluate_grpo(
             ref_model=ref_model
         )

-        # Accumulate losses and tokens
         all_losses += losses * toks
         ntokens += toks

-        # Accumulate metrics
         if all_metrics is None:
             all_metrics = {k: v * toks for k, v in metrics.items()}
         else:
             for k, v in metrics.items():
                 all_metrics[k] += v * toks

-        # Evaluate accumulated values
         mx.eval(all_losses, ntokens)

     # Aggregate across distributed workers
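Note (illustrative sketch, not part of the commit): losses and metrics are token-weighted, so each batch contributes loss * toks and the final division by the total token count (assumed from the accumulation; not shown in this hunk) weights long batches proportionally.

    batch_stats = [(2.0, 10), (4.0, 30)]  # (mean loss, token count) per batch
    total = sum(loss * toks for loss, toks in batch_stats)
    ntokens = sum(toks for _, toks in batch_stats)
    avg_loss = total / ntokens  # 3.5, not the unweighted 3.0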
@@ -475,9 +459,9 @@ def train_grpo(
     ref_model: Optional[nn.Module],
     tokenizer,
     optimizer,
-    train_dataset: List[GRPOExample],
-    val_dataset: List[GRPOExample],
-    reward_funcs = [
+    train_dataset,
+    val_dataset,
+    reward_funcs: Optional[List[RewardFunctions]] = [
         r1_accuracy_reward_func,
         r1_int_reward_func,
         r1_strict_format_reward_func,
@@ -2,8 +2,7 @@
 import json
 import types
 from pathlib import Path
-from typing import Dict, List
-from dataclasses import dataclass
+from typing import Dict

 import mlx.core as mx
 import mlx.nn as nn
@@ -276,19 +275,3 @@ def print_trainable_parameters(model):
         f"Trainable parameters: {(trainable_p * 100 / total_p):.3f}% "
         f"({trainable_p:.3f}M/{total_p:.3f}M)"
     )
-
-
-@dataclass
-class GRPOExample:
-    """Single example for GRPO training/inference."""
-    prompt_tokens: List[int]
-    answer_tokens: List[int]
-    prompt_text: str
-    answer_text: str
-
-
-@dataclass
-class GRPOBatch:
-    """A batch of GRPO examples."""
-    prompt_tokens: List[List[int]]
-    answer_tokens: List[List[int]]
-    prompt_texts: List[str]
-    answer_texts: List[str]