starting first training test run

Goekdeniz-Guelmez 2025-02-03 10:08:28 +01:00
parent 41ff5364d7
commit 23d75cd7ad
3 changed files with 109 additions and 77 deletions

View File

@@ -174,6 +174,7 @@ def build_parser():
     )
     parser.add_argument("--seed", type=int, help="The PRNG seed")
+    # GRPO args
     parser.add_argument(
         "--group-size",
         type=int,
@@ -270,12 +271,13 @@ def train_model(
     if args.reference_model_path:
         reference_model, _ = load(args.reference_model_path)
+        reference_model = reference_model.freeze()
     else:
-        reference_model, _ = load(args.model)
+        reference_model, _ = None, None

     train_grpo(
         model=model,
-        reference_model=reference_model.freeze(),
+        ref_model=reference_model,
         tokenizer=tokenizer,
         optimizer=opt,
         train_dataset=train_set,
@@ -318,7 +320,7 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set
         test_loss, test_rewards = evaluate_grpo(
             model=model,
-            reference_model=reference_model,
+            ref_model=reference_model,
             dataset=test_set,
             tokenizer=tokenizer,
             batch_size=args.batch_size,
@@ -326,8 +328,7 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set
             max_seq_length=args.max_seq_length,
             beta=args.beta,
             group_size=args.group_size,
-            epsilon=args.epsilon,
-            reference_model_path=args.reference_model_path
+            epsilon=args.epsilon
         )
         print(f"Test loss {test_loss:.3f}, Rewards: {test_rewards[0]:.3f}, {test_rewards[1]:.3f}")
     else:
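A note on the hunks above: the trainer entry point now loads and freezes a reference model only when `--reference-model-path` is given, and otherwise passes `None` so the GRPO trainer can decide how to handle the missing reference. The sketch below mirrors that control flow; `resolve_reference_model` and `load_fn` are illustrative names, not part of the repo.

```python
from typing import Any, Callable, Optional, Tuple

def resolve_reference_model(
    reference_model_path: Optional[str],
    load_fn: Callable[[str], Tuple[Any, Any]],  # stand-in for mlx_lm's load(), returns (model, tokenizer)
) -> Optional[Any]:
    """Mirror of the updated train_model logic: load and freeze a reference
    model only when a path is provided, otherwise return None."""
    if reference_model_path:
        reference_model, _ = load_fn(reference_model_path)
        # Freeze so the reference model never receives gradient updates.
        return reference_model.freeze()
    return None
```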

View File

@ -1,6 +1,6 @@
import json import json
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional from typing import Dict, List, Optional, Tuple
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
@@ -9,36 +9,30 @@ class GRPODataset:
     """
     Dataset wrapper for GRPO training data.
     Each example should have a 'prompt' and 'answer' field.
+    Returns data in (prompt, answer) tuple format required by GRPO trainer.
     """
     def __init__(
         self,
         data: List[Dict[str, str]],
         tokenizer: PreTrainedTokenizer,
         prompt_key: str = "prompt",
         answer_key: str = "answer"
     ):
         self._data = []
         for item in data:
-            # Tokenize prompt and answer
-            prompt_tokens = tokenizer.encode(item[prompt_key])
-            answer_tokens = tokenizer.encode(item[answer_key])
-            # Add EOS tokens if needed
-            if prompt_tokens[-1] != tokenizer.eos_token_id:
-                prompt_tokens.append(tokenizer.eos_token_id)
-            if answer_tokens[-1] != tokenizer.eos_token_id:
-                answer_tokens.append(tokenizer.eos_token_id)
-            self._data.append({
-                'prompt': prompt_tokens,
-                'answer': answer_tokens
-            })
+            # Get prompt and answer text
+            prompt = str(item[prompt_key])
+            answer = str(item[answer_key])
+            # Store as (prompt, answer) tuple
+            self._data.append((prompt, answer))

-    def __getitem__(self, idx: int) -> Dict[str, List[int]]:
+    def __getitem__(self, idx: int) -> Tuple[str, str]:
+        """Returns a (prompt, answer) tuple for the given index."""
         return self._data[idx]

     def __len__(self) -> int:
+        """Returns the number of examples in the dataset."""
         return len(self._data)
@@ -127,8 +121,11 @@ def create_dataset(
     prompt_feature = prompt_feature or "prompt"
     completion_feature = completion_feature or "completion"
     sample = data[0]
     if "messages" in sample:
         return ChatDataset(data, tokenizer)
+    elif "prompt" in sample and "answer" in sample:
+        return GRPODataset(data, tokenizer, "prompt", "answer")  # Use GRPO Dataset
     elif prompt_feature in sample and completion_feature in sample:
         return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature)
     elif "text" in sample:

View File

@@ -10,11 +10,10 @@ import mlx.nn as nn
 import numpy as np
 from mlx.utils import tree_flatten

-from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients
+from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients, iterate_batches
 from mlx_lm import generate

-generate()

 @dataclass
 class GRPOTrainingArgs(TrainingArgs):
@@ -263,55 +262,66 @@ def grpo_loss(
     return loss, sequence_lengths.sum(), metrics


-def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
-    # Sort by length:
-    idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
+def iterate_grpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
+    """
+    Creates batches from prompt-answer pairs for GRPO training.
+
+    Args:
+        dataset: List of (prompt, answer) pairs
+        tokenizer: Tokenizer for processing inputs
+        batch_size: Size of each batch
+        max_seq_length: Maximum sequence length
+        train: Whether this is for training
+
+    Yields:
+        List of prompts for the current batch
+    """
+    # Verify dataset is not empty and has correct format
+    if not dataset or not isinstance(dataset[0], (tuple, list)) or len(dataset[0]) != 2:
+        raise ValueError("Dataset must be a list of (prompt, answer) pairs")
+
+    # Sort by combined length of prompt + answer
+    idx = sorted(range(len(dataset)),
+                 key=lambda i: len(dataset[i][0]) + len(dataset[i][1]))
+
     if len(dataset) < batch_size:
         raise ValueError(
-            f"Dataset must have at least batch_size={batch_size}"
-            f" examples but only has {len(dataset)}."
+            f"Dataset must have at least batch_size={batch_size} "
+            f"examples but only has {len(dataset)}."
         )

-    # If running in distributed mode (N machines) then each one should skip N-1
-    # samples
+    # Handle distributed training
     step = mx.distributed.init().size()
     if batch_size % step != 0:
         raise ValueError("The batch size must be divisible by the number of workers")

-    # Make the batches:
+    # Create batch indices
     batch_idx = [
         idx[i : i + batch_size : step]
         for i in range(0, len(idx) - batch_size + 1, batch_size)
     ]

     while True:
-        indices = np.random.permutation(len(batch_idx))
+        # Shuffle batch indices if training
+        indices = np.random.permutation(len(batch_idx)) if train else range(len(batch_idx))
         for i in indices:
-            batch = [dataset[j] for j in batch_idx[i]]
-            lengths = [len(x) for x in batch]
-            if max(lengths) > max_seq_length:
+            # Get current batch of prompt-answer pairs
+            current_batch = [dataset[j] for j in batch_idx[i]]
+            # Extract prompts and answers
+            prompts = [pair[0] for pair in current_batch]
+            answers = [pair[1] for pair in current_batch]
+            if any(len(p) > max_seq_length for p in prompts):
                 print(
-                    f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "
-                    f"The longest sentence {max(lengths)} will be truncated to {max_seq_length}. "
-                    "Consider pre-splitting your data to save memory."
+                    f"[WARNING] Some prompts are longer than {max_seq_length} tokens. "
+                    "Long prompts will be truncated."
                 )

-            # Pad to the nearest multiple of 8 or the maximum length
-            pad_to = 8
-            max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
-            max_length_in_batch = min(max_length_in_batch, max_seq_length)
-            batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
-            for j in range(batch_size // step):
-                truncated_length = min(lengths[j], max_seq_length)
-                batch_arr[j, :truncated_length] = batch[j][:truncated_length]
-                lengths[j] = (
-                    truncated_length  # Update lengths to match truncated lengths
-                )
-            batch = mx.array(batch_arr)
-            yield batch[:, :-1], batch[:, 1:], mx.array(lengths)
+            # For GRPO, we only need to yield the prompts
+            # The answers will be used by the reward functions
+            yield prompts

         if not train:
             break
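To make the new batching behaviour concrete, here is a self-contained sketch of the core index logic from the hunk above: sort the (prompt, answer) pairs by combined length, build strided batch-index slices, and yield only the prompts. The tokenizer, `mx.distributed`, and the truncation warning are stripped out so it runs anywhere; `toy_batches` is an illustrative name, not part of the repo.

```python
import numpy as np

def toy_batches(dataset, batch_size, step=1, train=False):
    """Simplified mirror of iterate_grpo_batches: yields lists of prompts."""
    # Sort by combined length of prompt + answer
    idx = sorted(range(len(dataset)),
                 key=lambda i: len(dataset[i][0]) + len(dataset[i][1]))
    # Strided slices, as in the original (step = number of distributed workers)
    batch_idx = [idx[i:i + batch_size:step]
                 for i in range(0, len(idx) - batch_size + 1, batch_size)]
    order = np.random.permutation(len(batch_idx)) if train else range(len(batch_idx))
    for i in order:
        current = [dataset[j] for j in batch_idx[i]]
        # Only prompts are yielded; answers are left to the reward functions.
        yield [pair[0] for pair in current]

pairs = [("short?", "yes"), ("a slightly longer prompt?", "maybe"),
         ("p", "a"), ("medium prompt?", "ok")]
for prompts in toy_batches(pairs, batch_size=2):
    print(prompts)
# ['p', 'short?'] then ['medium prompt?', 'a slightly longer prompt?']
```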
@@ -325,12 +335,12 @@ def evaluate_grpo(
     batch_size,
     num_batches,
     beta: float,
-    epslion: float,
+    epsilon: float,
     group_size: int,
     max_seq_length,
     reward_funcs = None,
     loss: callable = grpo_loss,
-    iterate_batches: callable = iterate_batches
+    iterate_batches: callable = iterate_grpo_batches
 ):
     all_losses = 0
     ntokens = 0
@@ -354,7 +364,7 @@ def evaluate_grpo(
             reward_funcs=reward_funcs,
             beta=beta,
             group_size=group_size,
-            epslion=epslion,
+            epsilon=epsilon,
             ref_model=ref_model
         )
         all_losses += losses * toks
@@ -394,10 +404,10 @@ def train_grpo(
     ],
     args: GRPOTrainingArgs = GRPOTrainingArgs(),
     loss: callable = grpo_loss,
-    iterate_batches: callable = iterate_batches,
+    iterate_batches: callable = iterate_grpo_batches,
     training_callback: TrainingCallback = None,
 ):
-    print(f"Starting GRPO training..., iters: {args.iters}")
+    print(f"Starting GRPO training with {len(reward_funcs)} reward functions..., iters: {args.iters}")
     world = mx.distributed.init()
     world_size = world.size()
     rank = world.rank()
@@ -434,6 +444,9 @@ def train_grpo(
         'grouped_rewards_std': 0,
         'kl': 0
     }
+    for i in range(len(reward_funcs)):
+        accumulated_metrics[f'reward_func_{i}_mean'] = 0
+        accumulated_metrics[f'reward_func_{i}_std'] = 0

     start = time.perf_counter()
     for it, batch in zip(
@@ -454,26 +467,37 @@ def train_grpo(
                 model=model,
                 dataset=val_dataset,
                 loss=loss,
-                ref_model=ref_model,
+                ref_model=model,
                 reward_funcs=reward_funcs,
                 tokenizer=tokenizer,
+                group_size=args.group_size,
                 batch_size=args.batch_size,
                 num_batches=args.val_batches,
                 max_seq_length=args.max_seq_length,
+                beta=args.beta,
+                epsilon=args.epsilon,
                 iterate_batches=iterate_batches,
             )
             val_time = time.perf_counter() - stop
             if rank == 0:
-                print(
-                    f"Iter {it}: "
+                val_metrics_str = (
                     f"Val loss {val_loss:.8f}, "
-                    f"Val rewards {val_metrics['rewards']:.3f}, "
-                    f"Val rewards_std {val_metrics['rewards_std']:.3f}, "
-                    f"Val grouped_rewards {val_metrics['grouped_rewards']:.3f}, "
+                    f"Val total_rewards_mean {val_metrics['total_rewards_mean']:.3f}, "
+                    f"Val total_rewards_std {val_metrics['total_rewards_std']:.3f}, "
+                    f"Val grouped_rewards_mean {val_metrics['grouped_rewards_mean']:.3f}, "
                     f"Val grouped_rewards_std {val_metrics['grouped_rewards_std']:.3f}, "
-                    f"Val kl {val_metrics['kl']:.3f}, "
+                    f"Val kl {val_metrics['kl']:.3f}"
+                )
+                # Add reward function specific metrics
+                for i in range(len(reward_funcs)):
+                    val_metrics_str += (
+                        f", Val reward_func_{i}_mean {val_metrics[f'reward_func_{i}_mean']:.3f}, "
+                        f"Val reward_func_{i}_std {val_metrics[f'reward_func_{i}_std']:.3f}"
+                    )
+                print(
+                    f"Iter {it}: {val_metrics_str}, "
                     f"Val took {val_time:.3f}s",
                     flush=True,
                 )
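The new logging simply appends a `reward_func_{i}_mean` / `reward_func_{i}_std` pair to the base metrics string for each reward function. A quick illustration of the resulting line, using the metric keys from the diff and made-up numbers:

```python
# Dummy validation metrics for two reward functions; the key names mirror the
# diff, the values are invented purely for this illustration.
val_metrics = {
    "total_rewards_mean": 0.412, "kl": 0.013,
    "reward_func_0_mean": 0.250, "reward_func_0_std": 0.040,
    "reward_func_1_mean": 0.162, "reward_func_1_std": 0.055,
}
msg = (
    f"Val total_rewards_mean {val_metrics['total_rewards_mean']:.3f}, "
    f"Val kl {val_metrics['kl']:.3f}"
)
for i in range(2):
    msg += (
        f", Val reward_func_{i}_mean {val_metrics[f'reward_func_{i}_mean']:.3f}, "
        f"Val reward_func_{i}_std {val_metrics[f'reward_func_{i}_std']:.3f}"
    )
print(msg)
# Val total_rewards_mean 0.412, Val kl 0.013, Val reward_func_0_mean 0.250, ...
```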
@@ -510,14 +534,24 @@ def train_grpo(
             peak_mem = mx.metal.get_peak_memory() / 1e9

             if rank == 0:
+                train_metrics_str = (
+                    f"Train loss {train_loss:.8f}, "
+                    f"Total rewards mean {avg_metrics['total_rewards_mean']:.3f}, "
+                    f"Total rewards std {avg_metrics['total_rewards_std']:.3f}, "
+                    f"Grouped rewards mean {avg_metrics['grouped_rewards_mean']:.3f}, "
+                    f"Grouped rewards std {avg_metrics['grouped_rewards_std']:.3f}, "
+                    f"KL {avg_metrics['kl']:.3f}"
+                )
+
+                # Add reward function specific metrics
+                for i in range(len(reward_funcs)):
+                    train_metrics_str += (
+                        f", Reward func {i} mean {avg_metrics[f'reward_func_{i}_mean']:.3f}, "
+                        f"Reward func {i} std {avg_metrics[f'reward_func_{i}_std']:.3f}"
+                    )
+
                 print(
-                    f"Iter {it}: Train loss {train_loss:.8f}, "
-                    f"Rewards {avg_metrics['rewards']:.3f}, "
-                    f"Rewards_std {avg_metrics['rewards_std']:.3f}, "
-                    f"Grouped Rewards {avg_metrics['grouped_rewards']:.3f}, "
-                    f"Grouped Rewards {avg_metrics['grouped_rewards']:.3f}, "
-                    f"Grouped Rewards_std {val_metrics['grouped_rewards_std']:.3f}, "
-                    f"KL {val_metrics['kl']:.3f}, "
+                    f"Iter {it}: {train_metrics_str}, "
                     f"Learning Rate {learning_rate:.3e}, "
                     f"It/sec {it_sec:.3f}, "
                     f"Tokens/sec {tokens_sec:.3f}, "