starting first training test run

This commit is contained in:
Goekdeniz-Guelmez 2025-02-03 10:08:28 +01:00
parent 41ff5364d7
commit 23d75cd7ad
3 changed files with 109 additions and 77 deletions

View File

@@ -174,6 +174,7 @@ def build_parser():
)
parser.add_argument("--seed", type=int, help="The PRNG seed")
# GRPO args
parser.add_argument(
"--group-size",
type=int,
@@ -270,12 +271,13 @@ def train_model(
if args.reference_model_path:
reference_model, _ = load(args.reference_model_path)
reference_model = reference_model.freeze()
else:
reference_model, _ = load(args.model)
reference_model, _ = None, None
train_grpo(
model=model,
reference_model=reference_model.freeze(),
ref_model=reference_model,
tokenizer=tokenizer,
optimizer=opt,
train_dataset=train_set,
@@ -318,7 +320,7 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set
test_loss, test_rewards = evaluate_grpo(
model=model,
reference_model=reference_model,
ref_model=reference_model,
dataset=test_set,
tokenizer=tokenizer,
batch_size=args.batch_size,
@@ -326,8 +328,7 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set
max_seq_length=args.max_seq_length,
beta=args.beta,
group_size=args.group_size,
epsilon=args.epsilon,
reference_model_path=args.reference_model_path
epsilon=args.epsilon
)
print(f"Test loss {test_loss:.3f}, Rewards: {test_rewards[0]:.3f}, {test_rewards[1]:.3f}")
else:
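The net effect in this file is that the caller now resolves the reference model once and hands it to the trainer as ref_model, with None meaning "no separate reference model". A minimal sketch of that resolution step, assuming load() from mlx_lm (which returns a (model, tokenizer) pair) and an argparse namespace shaped like the one built above:

# Sketch only: mirrors the new reference-model handling in train_model.
from mlx_lm import load

def resolve_reference_model(args):
    """Load and freeze a reference model if a path was given, else return None."""
    if args.reference_model_path:
        reference_model, _ = load(args.reference_model_path)
        return reference_model.freeze()
    # No path given: the trainer receives ref_model=None and is expected
    # to handle the fallback itself (e.g. compute the KL term against the policy).
    return None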

View File

@@ -1,6 +1,6 @@
import json
from pathlib import Path
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple
from transformers import PreTrainedTokenizer
@@ -9,36 +9,30 @@ class GRPODataset:
"""
Dataset wrapper for GRPO training data.
Each example should have a 'prompt' and 'answer' field.
Returns data in (prompt, answer) tuple format required by GRPO trainer.
"""
def __init__(
self,
data: List[Dict[str, str]],
tokenizer: PreTrainedTokenizer,
prompt_key: str = "prompt",
prompt_key: str = "prompt",
answer_key: str = "answer"
):
self._data = []
for item in data:
# Tokenize prompt and answer
prompt_tokens = tokenizer.encode(item[prompt_key])
answer_tokens = tokenizer.encode(item[answer_key])
# Get prompt and answer text
prompt = str(item[prompt_key])
answer = str(item[answer_key])
# Add EOS tokens if needed
if prompt_tokens[-1] != tokenizer.eos_token_id:
prompt_tokens.append(tokenizer.eos_token_id)
if answer_tokens[-1] != tokenizer.eos_token_id:
answer_tokens.append(tokenizer.eos_token_id)
self._data.append({
'prompt': prompt_tokens,
'answer': answer_tokens
})
# Store as (prompt, answer) tuple
self._data.append((prompt, answer))
def __getitem__(self, idx: int) -> Dict[str, List[int]]:
def __getitem__(self, idx: int) -> Tuple[str, str]:
"""Returns a (prompt, answer) tuple for the given index."""
return self._data[idx]
def __len__(self) -> int:
"""Returns the number of examples in the dataset."""
return len(self._data)
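After this rewrite the dataset stores raw strings instead of token id lists, so indexing yields a (prompt, answer) tuple that the trainer tokenizes (and scores) later. A usage sketch, with the tokenizer checkpoint chosen purely for illustration:

# Illustrative usage; GRPODataset is the class defined above.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # placeholder checkpoint
data = [
    {"prompt": "What is 2 + 2?", "answer": "4"},
    {"prompt": "Name a prime number.", "answer": "7"},
]
dataset = GRPODataset(data, tokenizer, prompt_key="prompt", answer_key="answer")

prompt, answer = dataset[0]   # plain strings, no token ids
assert isinstance(prompt, str) and isinstance(answer, str)
print(len(dataset))           # 2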
@@ -127,8 +121,11 @@ def create_dataset(
prompt_feature = prompt_feature or "prompt"
completion_feature = completion_feature or "completion"
sample = data[0]
if "messages" in sample:
return ChatDataset(data, tokenizer)
elif "prompt" in sample and "answer" in sample:
return GRPODataset(data, tokenizer, "prompt", "answer") # Use GRPO Dataset
elif prompt_feature in sample and completion_feature in sample:
return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature)
elif "text" in sample:
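Because create_dataset only inspects data[0], the record shape decides the wrapper, and the new prompt/answer check sits above the completions check, so a record carrying both fields always becomes a GRPODataset. Roughly, with illustrative field values:

# Illustrative records only; create_dataset picks a wrapper from the keys of data[0].
chat_sample = {"messages": [{"role": "user", "content": "hi"}]}           # -> ChatDataset
grpo_sample = {"prompt": "What is 2 + 2?", "answer": "4"}                 # -> GRPODataset
completion_sample = {"prompt": "Translate: hola", "completion": "hello"}  # -> CompletionsDataset
text_sample = {"text": "plain pre-training style text"}                   # -> text branch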

View File

@@ -10,11 +10,10 @@ import mlx.nn as nn
import numpy as np
from mlx.utils import tree_flatten
from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients
from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients, iterate_batches
from mlx_lm import generate
generate()
@dataclass
class GRPOTrainingArgs(TrainingArgs):
@@ -263,55 +262,66 @@ def grpo_loss(
return loss, sequence_lengths.sum(), metrics
def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
# Sort by length:
idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
def iterate_grpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
"""
Creates batches from prompt-answer pairs for GRPO training.
Args:
dataset: List of (prompt, answer) pairs
tokenizer: Tokenizer for processing inputs
batch_size: Size of each batch
max_seq_length: Maximum sequence length
train: Whether this is for training
Yields:
List of prompts for the current batch
"""
# Verify dataset is not empty and has correct format
if not dataset or not isinstance(dataset[0], (tuple, list)) or len(dataset[0]) != 2:
raise ValueError("Dataset must be a list of (prompt, answer) pairs")
# Sort by combined length of prompt + answer
idx = sorted(range(len(dataset)),
key=lambda i: len(dataset[i][0]) + len(dataset[i][1]))
if len(dataset) < batch_size:
raise ValueError(
f"Dataset must have at least batch_size={batch_size}"
f" examples but only has {len(dataset)}."
f"Dataset must have at least batch_size={batch_size} "
f"examples but only has {len(dataset)}."
)
# If running in distributed mode (N machines) then each one should skip N-1
# samples
# Handle distributed training
step = mx.distributed.init().size()
if batch_size % step != 0:
raise ValueError("The batch size must be divisible by the number of workers")
# Make the batches:
# Create batch indices
batch_idx = [
idx[i : i + batch_size : step]
for i in range(0, len(idx) - batch_size + 1, batch_size)
]
while True:
indices = np.random.permutation(len(batch_idx))
# Shuffle batch indices if training
indices = np.random.permutation(len(batch_idx)) if train else range(len(batch_idx))
for i in indices:
batch = [dataset[j] for j in batch_idx[i]]
lengths = [len(x) for x in batch]
if max(lengths) > max_seq_length:
# Get current batch of prompt-answer pairs
current_batch = [dataset[j] for j in batch_idx[i]]
# Extract prompts and answers
prompts = [pair[0] for pair in current_batch]
answers = [pair[1] for pair in current_batch]
if any(len(p) > max_seq_length for p in prompts):
print(
f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "
f"The longest sentence {max(lengths)} will be truncated to {max_seq_length}. "
"Consider pre-splitting your data to save memory."
f"[WARNING] Some prompts are longer than {max_seq_length} tokens. "
"Long prompts will be truncated."
)
# Pad to the nearest multiple of 8 or the maximum length
pad_to = 8
max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
max_length_in_batch = min(max_length_in_batch, max_seq_length)
batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
for j in range(batch_size // step):
truncated_length = min(lengths[j], max_seq_length)
batch_arr[j, :truncated_length] = batch[j][:truncated_length]
lengths[j] = (
truncated_length # Update lengths to match truncated lengths
)
batch = mx.array(batch_arr)
yield batch[:, :-1], batch[:, 1:], mx.array(lengths)
# For GRPO, we only need to yield the prompts
# The answers will be used by the reward functions
yield prompts
if not train:
break
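The batching arithmetic can be checked without mx.distributed: sort indices by combined prompt+answer length, slice each batch_size-wide window with a stride equal to the number of workers, and yield only the prompts (the answers stay with the reward functions). A simplified, single-pass sketch where world_size stands in for mx.distributed.init().size():

import numpy as np

def sketch_grpo_batches(dataset, batch_size, world_size=1, train=False):
    """Simplified re-implementation of the index logic in iterate_grpo_batches."""
    # Sort by combined prompt + answer length (both are strings here).
    idx = sorted(range(len(dataset)),
                 key=lambda i: len(dataset[i][0]) + len(dataset[i][1]))
    assert batch_size % world_size == 0, "batch size must divide across workers"
    # Each worker takes every world_size-th index inside a batch_size-wide window.
    batch_idx = [idx[i : i + batch_size : world_size]
                 for i in range(0, len(idx) - batch_size + 1, batch_size)]
    order = np.random.permutation(len(batch_idx)) if train else range(len(batch_idx))
    for i in order:  # single pass; the real generator loops forever when train=True
        batch = [dataset[j] for j in batch_idx[i]]
        yield [prompt for prompt, _answer in batch]

pairs = [("short prompt", "a"), ("a much longer prompt here", "bb"),
         ("mid prompt", "ccc"), ("tiny", "d")]
for prompts in sketch_grpo_batches(pairs, batch_size=2):
    print(prompts)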
@@ -325,12 +335,12 @@ def evaluate_grpo(
batch_size,
num_batches,
beta: float,
epslion: float,
epsilon: float,
group_size: int,
max_seq_length,
reward_funcs = None,
loss: callable = grpo_loss,
iterate_batches: callable = iterate_batches
iterate_batches: callable = iterate_grpo_batches
):
all_losses = 0
ntokens = 0
@@ -354,7 +364,7 @@ def evaluate_grpo(
reward_funcs=reward_funcs,
beta=beta,
group_size=group_size,
epslion=epslion,
epsilon=epsilon,
ref_model=ref_model
)
all_losses += losses * toks
@@ -394,10 +404,10 @@ def train_grpo(
],
args: GRPOTrainingArgs = GRPOTrainingArgs(),
loss: callable = grpo_loss,
iterate_batches: callable = iterate_batches,
iterate_batches: callable = iterate_grpo_batches,
training_callback: TrainingCallback = None,
):
print(f"Starting GRPO training..., iters: {args.iters}")
print(f"Starting GRPO training with {len(reward_funcs)} reward functions..., iters: {args.iters}")
world = mx.distributed.init()
world_size = world.size()
rank = world.rank()
@@ -434,6 +444,9 @@ def train_grpo(
'grouped_rewards_std': 0,
'kl': 0
}
for i in range(len(reward_funcs)):
accumulated_metrics[f'reward_func_{i}_mean'] = 0
accumulated_metrics[f'reward_func_{i}_std'] = 0
start = time.perf_counter()
for it, batch in zip(
@@ -454,26 +467,37 @@
model=model,
dataset=val_dataset,
loss=loss,
ref_model=model,
ref_model=ref_model,
reward_funcs=reward_funcs,
tokenizer=tokenizer,
group_size=args.group_size,
batch_size=args.batch_size,
num_batches=args.val_batches,
max_seq_length=args.max_seq_length,
beta=args.beta,
epsilon=args.epsilon,
iterate_batches=iterate_batches,
)
val_time = time.perf_counter() - stop
if rank == 0:
print(
f"Iter {it}: "
val_metrics_str = (
f"Val loss {val_loss:.8f}, "
f"Val rewards {val_metrics['rewards']:.3f}, "
f"Val rewards_std {val_metrics['rewards_std']:.3f}, "
f"Val grouped_rewards {val_metrics['grouped_rewards']:.3f}, "
f"Val total_rewards_mean {val_metrics['total_rewards_mean']:.3f}, "
f"Val total_rewards_std {val_metrics['total_rewards_std']:.3f}, "
f"Val grouped_rewards_mean {val_metrics['grouped_rewards_mean']:.3f}, "
f"Val grouped_rewards_std {val_metrics['grouped_rewards_std']:.3f}, "
f"Val kl {val_metrics['kl']:.3f}, "
f"Val kl {val_metrics['kl']:.3f}"
)
# Add reward function specific metrics
for i in range(len(reward_funcs)):
val_metrics_str += (
f", Val reward_func_{i}_mean {val_metrics[f'reward_func_{i}_mean']:.3f}, "
f"Val reward_func_{i}_std {val_metrics[f'reward_func_{i}_std']:.3f}"
)
print(
f"Iter {it}: {val_metrics_str}, "
f"Val took {val_time:.3f}s",
flush=True,
)
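The per-reward-function block is appended to the base string by index; for reference, the same suffix could be built in one pass with a join (the metric values below are placeholders, not output from a real run):

# Placeholder values shaped like the accumulated validation metrics.
val_metrics = {
    "reward_func_0_mean": 0.42, "reward_func_0_std": 0.10,
    "reward_func_1_mean": 0.31, "reward_func_1_std": 0.07,
}
num_reward_funcs = 2

suffix = ", ".join(
    f"Val reward_func_{i}_mean {val_metrics[f'reward_func_{i}_mean']:.3f}, "
    f"Val reward_func_{i}_std {val_metrics[f'reward_func_{i}_std']:.3f}"
    for i in range(num_reward_funcs)
)
print(suffix)  # Val reward_func_0_mean 0.420, Val reward_func_0_std 0.100, ...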
@@ -510,14 +534,24 @@ def train_grpo(
peak_mem = mx.metal.get_peak_memory() / 1e9
if rank == 0:
train_metrics_str = (
f"Train loss {train_loss:.8f}, "
f"Total rewards mean {avg_metrics['total_rewards_mean']:.3f}, "
f"Total rewards std {avg_metrics['total_rewards_std']:.3f}, "
f"Grouped rewards mean {avg_metrics['grouped_rewards_mean']:.3f}, "
f"Grouped rewards std {avg_metrics['grouped_rewards_std']:.3f}, "
f"KL {avg_metrics['kl']:.3f}"
)
# Add reward function specific metrics
for i in range(len(reward_funcs)):
train_metrics_str += (
f", Reward func {i} mean {avg_metrics[f'reward_func_{i}_mean']:.3f}, "
f"Reward func {i} std {avg_metrics[f'reward_func_{i}_std']:.3f}"
)
print(
f"Iter {it}: Train loss {train_loss:.8f}, "
f"Rewards {avg_metrics['rewards']:.3f}, "
f"Rewards_std {avg_metrics['rewards_std']:.3f}, "
f"Grouped Rewards {avg_metrics['grouped_rewards']:.3f}, "
f"Grouped Rewards {avg_metrics['grouped_rewards']:.3f}, "
f"Grouped Rewards_std {val_metrics['grouped_rewards_std']:.3f}, "
f"KL {val_metrics['kl']:.3f}, "
f"Iter {it}: {train_metrics_str}, "
f"Learning Rate {learning_rate:.3e}, "
f"It/sec {it_sec:.3f}, "
f"Tokens/sec {tokens_sec:.3f}, "