Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-26 02:33:23 +08:00)
starting first training test run
This commit is contained in:
parent 41ff5364d7
commit 23d75cd7ad
@@ -174,6 +174,7 @@ def build_parser():
)
parser.add_argument("--seed", type=int, help="The PRNG seed")

# GRPO args
parser.add_argument(
"--group-size",
type=int,
@@ -270,12 +271,13 @@ def train_model(

if args.reference_model_path:
reference_model, _ = load(args.reference_model_path)
reference_model = reference_model.freeze()
else:
reference_model, _ = load(args.model)
reference_model, _ = None, None

train_grpo(
model=model,
reference_model=reference_model.freeze(),
ref_model=reference_model,
tokenizer=tokenizer,
optimizer=opt,
train_dataset=train_set,
@@ -318,7 +320,7 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set

test_loss, test_rewards = evaluate_grpo(
model=model,
reference_model=reference_model,
ref_model=reference_model,
dataset=test_set,
tokenizer=tokenizer,
batch_size=args.batch_size,
@@ -326,8 +328,7 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set
max_seq_length=args.max_seq_length,
beta=args.beta,
group_size=args.group_size,
epsilon=args.epsilon,
reference_model_path=args.reference_model_path
epsilon=args.epsilon
)
print(f"Test loss {test_loss:.3f}, Rewards: {test_rewards[0]:.3f}, {test_rewards[1]:.3f}")
else:
@@ -1,6 +1,6 @@
import json
from pathlib import Path
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple

from transformers import PreTrainedTokenizer
@@ -9,36 +9,30 @@ class GRPODataset:
"""
Dataset wrapper for GRPO training data.
Each example should have a 'prompt' and 'answer' field.
Returns data in (prompt, answer) tuple format required by GRPO trainer.
"""
def __init__(
self,
data: List[Dict[str, str]],
tokenizer: PreTrainedTokenizer,
prompt_key: str = "prompt",
prompt_key: str = "prompt",
answer_key: str = "answer"
):
self._data = []

for item in data:
# Tokenize prompt and answer
prompt_tokens = tokenizer.encode(item[prompt_key])
answer_tokens = tokenizer.encode(item[answer_key])
# Get prompt and answer text
prompt = str(item[prompt_key])
answer = str(item[answer_key])

# Add EOS tokens if needed
if prompt_tokens[-1] != tokenizer.eos_token_id:
prompt_tokens.append(tokenizer.eos_token_id)
if answer_tokens[-1] != tokenizer.eos_token_id:
answer_tokens.append(tokenizer.eos_token_id)

self._data.append({
'prompt': prompt_tokens,
'answer': answer_tokens
})
# Store as (prompt, answer) tuple
self._data.append((prompt, answer))

def __getitem__(self, idx: int) -> Dict[str, List[int]]:
def __getitem__(self, idx: int) -> Tuple[str, str]:
"""Returns a (prompt, answer) tuple for the given index."""
return self._data[idx]

def __len__(self) -> int:
"""Returns the number of examples in the dataset."""
return len(self._data)
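For context, a minimal usage sketch of the reworked GRPODataset above. The records below are made up for illustration, and tok stands in for the tokenizer returned by mlx_lm's load(); the new implementation only stores the raw strings:

# Hypothetical records; any list of dicts with "prompt" and "answer" fields works.
records = [
    {"prompt": "What is 2 + 2?", "answer": "4"},
    {"prompt": "Name a prime number.", "answer": "7"},
]
dataset = GRPODataset(records, tok, prompt_key="prompt", answer_key="answer")
prompt, answer = dataset[0]  # each item is now a (prompt, answer) tuple of strings
assert len(dataset) == 2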
@@ -127,8 +121,11 @@ def create_dataset(
prompt_feature = prompt_feature or "prompt"
completion_feature = completion_feature or "completion"
sample = data[0]

if "messages" in sample:
return ChatDataset(data, tokenizer)
elif "prompt" in sample and "answer" in sample:
return GRPODataset(data, tokenizer, "prompt", "answer") # Use GRPO Dataset
elif prompt_feature in sample and completion_feature in sample:
return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature)
elif "text" in sample:
@@ -10,11 +10,10 @@ import mlx.nn as nn
import numpy as np
from mlx.utils import tree_flatten

from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients
from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients, iterate_batches

from mlx_lm import generate

generate()

@dataclass
class GRPOTrainingArgs(TrainingArgs):
@@ -263,55 +262,66 @@ def grpo_loss(
return loss, sequence_lengths.sum(), metrics


def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
# Sort by length:
idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
def iterate_grpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
"""
Creates batches from prompt-answer pairs for GRPO training.

Args:
dataset: List of (prompt, answer) pairs
tokenizer: Tokenizer for processing inputs
batch_size: Size of each batch
max_seq_length: Maximum sequence length
train: Whether this is for training

Yields:
List of prompts for the current batch
"""
# Verify dataset is not empty and has correct format
if not dataset or not isinstance(dataset[0], (tuple, list)) or len(dataset[0]) != 2:
raise ValueError("Dataset must be a list of (prompt, answer) pairs")

# Sort by combined length of prompt + answer
idx = sorted(range(len(dataset)),
key=lambda i: len(dataset[i][0]) + len(dataset[i][1]))

if len(dataset) < batch_size:
raise ValueError(
f"Dataset must have at least batch_size={batch_size}"
f" examples but only has {len(dataset)}."
f"Dataset must have at least batch_size={batch_size} "
f"examples but only has {len(dataset)}."
)

# If running in distributed mode (N machines) then each one should skip N-1
# samples
# Handle distributed training
step = mx.distributed.init().size()
if batch_size % step != 0:
raise ValueError("The batch size must be divisible by the number of workers")

# Make the batches:
# Create batch indices
batch_idx = [
idx[i : i + batch_size : step]
for i in range(0, len(idx) - batch_size + 1, batch_size)
]
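# Example: with batch_size=8 and a world size of step=2, the slice
# idx[i : i + 8 : 2] keeps 4 of the 8 sorted indices, so each worker
# processes batch_size // step samples from every batch window.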

while True:
indices = np.random.permutation(len(batch_idx))
# Shuffle batch indices if training
indices = np.random.permutation(len(batch_idx)) if train else range(len(batch_idx))

for i in indices:
batch = [dataset[j] for j in batch_idx[i]]
lengths = [len(x) for x in batch]
if max(lengths) > max_seq_length:
# Get current batch of prompt-answer pairs
current_batch = [dataset[j] for j in batch_idx[i]]

# Extract prompts and answers
prompts = [pair[0] for pair in current_batch]
answers = [pair[1] for pair in current_batch]

if any(len(p) > max_seq_length for p in prompts):
print(
f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "
f"The longest sentence {max(lengths)} will be truncated to {max_seq_length}. "
"Consider pre-splitting your data to save memory."
f"[WARNING] Some prompts are longer than {max_seq_length} tokens. "
"Long prompts will be truncated."
)

# Pad to the nearest multiple of 8 or the maximum length
pad_to = 8
max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
max_length_in_batch = min(max_length_in_batch, max_seq_length)

batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)

for j in range(batch_size // step):
truncated_length = min(lengths[j], max_seq_length)
batch_arr[j, :truncated_length] = batch[j][:truncated_length]
lengths[j] = (
truncated_length # Update lengths to match truncated lengths
)
batch = mx.array(batch_arr)

yield batch[:, :-1], batch[:, 1:], mx.array(lengths)

# For GRPO, we only need to yield the prompts
# The answers will be used by the reward functions
yield prompts

if not train:
break
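A minimal sketch of driving the new iterator on its own, under the assumption that the tokenizer argument is unused on this prompt-only path; the toy pairs below are invented for illustration, and in training the function is plugged in through the iterate_batches argument of train_grpo:

# Hypothetical smoke test for iterate_grpo_batches with four (prompt, answer) pairs.
pairs = [
    ("What is 1 + 1?", "2"),
    ("What is 2 + 2?", "4"),
    ("What is 3 + 3?", "6"),
    ("What is 4 + 4?", "8"),
]
batches = iterate_grpo_batches(pairs, tokenizer=None, batch_size=4,
                               max_seq_length=512, train=False)
prompts = next(batches)  # a list of four prompt strings; answers stay in the dataset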
@@ -325,12 +335,12 @@ def evaluate_grpo(
batch_size,
num_batches,
beta: float,
epslion: float,
epsilon: float,
group_size: int,
max_seq_length,
reward_funcs = None,
loss: callable = grpo_loss,
iterate_batches: callable = iterate_batches
iterate_batches: callable = iterate_grpo_batches
):
all_losses = 0
ntokens = 0
@@ -354,7 +364,7 @@ def evaluate_grpo(
reward_funcs=reward_funcs,
beta=beta,
group_size=group_size,
epslion=epslion,
epsilon=epsilon,
ref_model=ref_model
)
all_losses += losses * toks
@@ -394,10 +404,10 @@ def train_grpo(
],
args: GRPOTrainingArgs = GRPOTrainingArgs(),
loss: callable = grpo_loss,
iterate_batches: callable = iterate_batches,
iterate_batches: callable = iterate_grpo_batches,
training_callback: TrainingCallback = None,
):
print(f"Starting GRPO training..., iters: {args.iters}")
print(f"Starting GRPO training with {len(reward_funcs)} reward functions..., iters: {args.iters}")
world = mx.distributed.init()
world_size = world.size()
rank = world.rank()
@@ -434,6 +444,9 @@ def train_grpo(
'grouped_rewards_std': 0,
'kl': 0
}
for i in range(len(reward_funcs)):
accumulated_metrics[f'reward_func_{i}_mean'] = 0
accumulated_metrics[f'reward_func_{i}_std'] = 0
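# Example: with two reward functions the loop above adds the keys
# 'reward_func_0_mean', 'reward_func_0_std', 'reward_func_1_mean' and
# 'reward_func_1_std' to accumulated_metrics, each initialised to 0.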

start = time.perf_counter()
for it, batch in zip(
@@ -454,26 +467,37 @@ def train_grpo(
model=model,
dataset=val_dataset,
loss=loss,

ref_model=model,
ref_model=ref_model,
reward_funcs=reward_funcs,

tokenizer=tokenizer,
group_size=args.group_size,
batch_size=args.batch_size,
num_batches=args.val_batches,
max_seq_length=args.max_seq_length,
beta=args.beta,
epsilon=args.epsilon,
iterate_batches=iterate_batches,
)
val_time = time.perf_counter() - stop
if rank == 0:
print(
f"Iter {it}: "
val_metrics_str = (
f"Val loss {val_loss:.8f}, "
f"Val rewards {val_metrics['rewards']:.3f}, "
f"Val rewards_std {val_metrics['rewards_std']:.3f}, "
f"Val grouped_rewards {val_metrics['grouped_rewards']:.3f}, "
f"Val total_rewards_mean {val_metrics['total_rewards_mean']:.3f}, "
f"Val total_rewards_std {val_metrics['total_rewards_std']:.3f}, "
f"Val grouped_rewards_mean {val_metrics['grouped_rewards_mean']:.3f}, "
f"Val grouped_rewards_std {val_metrics['grouped_rewards_std']:.3f}, "
f"Val kl {val_metrics['kl']:.3f}, "
f"Val kl {val_metrics['kl']:.3f}"
)

# Add reward function specific metrics
for i in range(len(reward_funcs)):
val_metrics_str += (
f", Val reward_func_{i}_mean {val_metrics[f'reward_func_{i}_mean']:.3f}, "
f"Val reward_func_{i}_std {val_metrics[f'reward_func_{i}_std']:.3f}"
)

print(
f"Iter {it}: {val_metrics_str}, "
f"Val took {val_time:.3f}s",
flush=True,
)
@@ -510,14 +534,24 @@ def train_grpo(
peak_mem = mx.metal.get_peak_memory() / 1e9

if rank == 0:
train_metrics_str = (
f"Train loss {train_loss:.8f}, "
f"Total rewards mean {avg_metrics['total_rewards_mean']:.3f}, "
f"Total rewards std {avg_metrics['total_rewards_std']:.3f}, "
f"Grouped rewards mean {avg_metrics['grouped_rewards_mean']:.3f}, "
f"Grouped rewards std {avg_metrics['grouped_rewards_std']:.3f}, "
f"KL {avg_metrics['kl']:.3f}"
)

# Add reward function specific metrics
for i in range(len(reward_funcs)):
train_metrics_str += (
f", Reward func {i} mean {avg_metrics[f'reward_func_{i}_mean']:.3f}, "
f"Reward func {i} std {avg_metrics[f'reward_func_{i}_std']:.3f}"
)

print(
f"Iter {it}: Train loss {train_loss:.8f}, "
f"Rewards {avg_metrics['rewards']:.3f}, "
f"Rewards_std {avg_metrics['rewards_std']:.3f}, "
f"Grouped Rewards {avg_metrics['grouped_rewards']:.3f}, "
f"Grouped Rewards {avg_metrics['grouped_rewards']:.3f}, "
f"Grouped Rewards_std {val_metrics['grouped_rewards_std']:.3f}, "
f"KL {val_metrics['kl']:.3f}, "
f"Iter {it}: {train_metrics_str}, "
f"Learning Rate {learning_rate:.3e}, "
f"It/sec {it_sec:.3f}, "
f"Tokens/sec {tokens_sec:.3f}, "