From 23d75cd7adfa48b1a34836e2367350026279e93b Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Mon, 3 Feb 2025 10:08:28 +0100
Subject: [PATCH] starting first training test run

---
 llms/mlx_lm/lora.py               |  11 +--
 llms/mlx_lm/tuner/datasets.py     |  31 +++----
 llms/mlx_lm/tuner/grpo_trainer.py | 144 ++++++++++++++++++------
 3 files changed, 109 insertions(+), 77 deletions(-)

diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index a5eb8ffe..1f684d27 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -174,6 +174,7 @@ def build_parser():
     )
     parser.add_argument("--seed", type=int, help="The PRNG seed")
 
+    # GRPO args
     parser.add_argument(
         "--group-size",
         type=int,
@@ -270,12 +271,13 @@ def train_model(
 
     if args.reference_model_path:
         reference_model, _ = load(args.reference_model_path)
+        reference_model = reference_model.freeze()
     else:
-        reference_model, _ = load(args.model)
+        reference_model = None
 
     train_grpo(
         model=model,
-        reference_model=reference_model.freeze(),
+        ref_model=reference_model,
         tokenizer=tokenizer,
         optimizer=opt,
         train_dataset=train_set,
@@ -318,7 +320,7 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set
 
         test_loss, test_rewards = evaluate_grpo(
             model=model,
-            reference_model=reference_model,
+            ref_model=reference_model,
             dataset=test_set,
             tokenizer=tokenizer,
             batch_size=args.batch_size,
@@ -326,8 +328,7 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set
             max_seq_length=args.max_seq_length,
             beta=args.beta,
             group_size=args.group_size,
-            epsilon=args.epsilon,
-            reference_model_path=args.reference_model_path
+            epsilon=args.epsilon
         )
         print(f"Test loss {test_loss:.3f}, Rewards: {test_rewards[0]:.3f}, {test_rewards[1]:.3f}")
     else:
diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py
index 8f185473..0c19031b 100644
--- a/llms/mlx_lm/tuner/datasets.py
+++ b/llms/mlx_lm/tuner/datasets.py
@@ -1,6 +1,6 @@
 import json
 from pathlib import Path
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple
 
 from transformers import PreTrainedTokenizer
 
@@ -9,36 +9,30 @@ class GRPODataset:
     """
     Dataset wrapper for GRPO training data.
     Each example should have a 'prompt' and 'answer' field.
+    Returns data in (prompt, answer) tuple format required by GRPO trainer.
""" def __init__( self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer, - prompt_key: str = "prompt", + prompt_key: str = "prompt", answer_key: str = "answer" ): self._data = [] - for item in data: - # Tokenize prompt and answer - prompt_tokens = tokenizer.encode(item[prompt_key]) - answer_tokens = tokenizer.encode(item[answer_key]) + # Get prompt and answer text + prompt = str(item[prompt_key]) + answer = str(item[answer_key]) - # Add EOS tokens if needed - if prompt_tokens[-1] != tokenizer.eos_token_id: - prompt_tokens.append(tokenizer.eos_token_id) - if answer_tokens[-1] != tokenizer.eos_token_id: - answer_tokens.append(tokenizer.eos_token_id) - - self._data.append({ - 'prompt': prompt_tokens, - 'answer': answer_tokens - }) + # Store as (prompt, answer) tuple + self._data.append((prompt, answer)) - def __getitem__(self, idx: int) -> Dict[str, List[int]]: + def __getitem__(self, idx: int) -> Tuple[str, str]: + """Returns a (prompt, answer) tuple for the given index.""" return self._data[idx] def __len__(self) -> int: + """Returns the number of examples in the dataset.""" return len(self._data) @@ -127,8 +121,11 @@ def create_dataset( prompt_feature = prompt_feature or "prompt" completion_feature = completion_feature or "completion" sample = data[0] + if "messages" in sample: return ChatDataset(data, tokenizer) + elif "prompt" in sample and "answer" in sample: + return GRPODataset(data, tokenizer, "prompt", "answer") # Use GRPO Dataset elif prompt_feature in sample and completion_feature in sample: return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature) elif "text" in sample: diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py index 31edc0ec..e997b504 100644 --- a/llms/mlx_lm/tuner/grpo_trainer.py +++ b/llms/mlx_lm/tuner/grpo_trainer.py @@ -10,11 +10,10 @@ import mlx.nn as nn import numpy as np from mlx.utils import tree_flatten -from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients +from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients, iterate_batches from mlx_lm import generate -generate() @dataclass class GRPOTrainingArgs(TrainingArgs): @@ -263,55 +262,66 @@ def grpo_loss( return loss, sequence_lengths.sum(), metrics -def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False): - # Sort by length: - idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx])) +def iterate_grpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False): + """ + Creates batches from prompt-answer pairs for GRPO training. + + Args: + dataset: List of (prompt, answer) pairs + tokenizer: Tokenizer for processing inputs + batch_size: Size of each batch + max_seq_length: Maximum sequence length + train: Whether this is for training + + Yields: + List of prompts for the current batch + """ + # Verify dataset is not empty and has correct format + if not dataset or not isinstance(dataset[0], (tuple, list)) or len(dataset[0]) != 2: + raise ValueError("Dataset must be a list of (prompt, answer) pairs") + + # Sort by combined length of prompt + answer + idx = sorted(range(len(dataset)), + key=lambda i: len(dataset[i][0]) + len(dataset[i][1])) + if len(dataset) < batch_size: raise ValueError( - f"Dataset must have at least batch_size={batch_size}" - f" examples but only has {len(dataset)}." + f"Dataset must have at least batch_size={batch_size} " + f"examples but only has {len(dataset)}." 
         )
 
-    # If running in distributed mode (N machines) then each one should skip N-1
-    # samples
+    # Handle distributed training
     step = mx.distributed.init().size()
     if batch_size % step != 0:
         raise ValueError("The batch size must be divisible by the number of workers")
 
-    # Make the batches:
+    # Create batch indices
     batch_idx = [
         idx[i : i + batch_size : step]
         for i in range(0, len(idx) - batch_size + 1, batch_size)
    ]
 
     while True:
-        indices = np.random.permutation(len(batch_idx))
+        # Shuffle batch indices if training
+        indices = np.random.permutation(len(batch_idx)) if train else range(len(batch_idx))
+
         for i in indices:
-            batch = [dataset[j] for j in batch_idx[i]]
-            lengths = [len(x) for x in batch]
-            if max(lengths) > max_seq_length:
+            # Get current batch of prompt-answer pairs
+            current_batch = [dataset[j] for j in batch_idx[i]]
+
+            # Extract prompts and answers
+            prompts = [pair[0] for pair in current_batch]
+            answers = [pair[1] for pair in current_batch]
+
+            if any(len(p) > max_seq_length for p in prompts):
                 print(
-                    f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "
-                    f"The longest sentence {max(lengths)} will be truncated to {max_seq_length}. "
-                    "Consider pre-splitting your data to save memory."
+                    f"[WARNING] Some prompts are longer than {max_seq_length} tokens. "
+                    "Long prompts will be truncated."
                 )
-
-            # Pad to the nearest multiple of 8 or the maximum length
-            pad_to = 8
-            max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
-            max_length_in_batch = min(max_length_in_batch, max_seq_length)
-
-            batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
-
-            for j in range(batch_size // step):
-                truncated_length = min(lengths[j], max_seq_length)
-                batch_arr[j, :truncated_length] = batch[j][:truncated_length]
-                lengths[j] = (
-                    truncated_length  # Update lengths to match truncated lengths
-                )
-            batch = mx.array(batch_arr)
-
-            yield batch[:, :-1], batch[:, 1:], mx.array(lengths)
+
+            # For GRPO, we only need to yield the prompts
+            # The answers will be used by the reward functions
+            yield prompts
 
         if not train:
             break
@@ -325,12 +335,12 @@ def evaluate_grpo(
     batch_size,
     num_batches,
     beta: float,
-    epslion: float,
+    epsilon: float,
     group_size: int,
     max_seq_length,
     reward_funcs = None,
     loss: callable = grpo_loss,
-    iterate_batches: callable = iterate_batches
+    iterate_batches: callable = iterate_grpo_batches
 ):
     all_losses = 0
     ntokens = 0
@@ -354,7 +364,7 @@ def evaluate_grpo(
             reward_funcs=reward_funcs,
             beta=beta,
             group_size=group_size,
-            epslion=epslion,
+            epsilon=epsilon,
             ref_model=ref_model
         )
         all_losses += losses * toks
@@ -394,10 +404,10 @@ def train_grpo(
     ],
     args: GRPOTrainingArgs = GRPOTrainingArgs(),
     loss: callable = grpo_loss,
-    iterate_batches: callable = iterate_batches,
+    iterate_batches: callable = iterate_grpo_batches,
     training_callback: TrainingCallback = None,
 ):
-    print(f"Starting GRPO training..., iters: {args.iters}")
+    print(f"Starting GRPO training with {len(reward_funcs)} reward functions..., iters: {args.iters}")
     world = mx.distributed.init()
     world_size = world.size()
     rank = world.rank()
@@ -434,6 +444,9 @@ def train_grpo(
         'grouped_rewards_std': 0,
         'kl': 0
     }
+    for i in range(len(reward_funcs)):
+        accumulated_metrics[f'reward_func_{i}_mean'] = 0
+        accumulated_metrics[f'reward_func_{i}_std'] = 0
 
     start = time.perf_counter()
     for it, batch in zip(
@@ -454,26 +467,37 @@ def train_grpo(
                 model=model,
                 dataset=val_dataset,
                 loss=loss,
-
-                ref_model=model,
+                ref_model=ref_model,
                 reward_funcs=reward_funcs,
                 tokenizer=tokenizer,
                 group_size=args.group_size,
                 batch_size=args.batch_size,
                 num_batches=args.val_batches,
                 max_seq_length=args.max_seq_length,
+                beta=args.beta,
+                epsilon=args.epsilon,
                 iterate_batches=iterate_batches,
             )
             val_time = time.perf_counter() - stop
             if rank == 0:
-                print(
-                    f"Iter {it}: "
+                val_metrics_str = (
                     f"Val loss {val_loss:.8f}, "
-                    f"Val rewards {val_metrics['rewards']:.3f}, "
-                    f"Val rewards_std {val_metrics['rewards_std']:.3f}, "
-                    f"Val grouped_rewards {val_metrics['grouped_rewards']:.3f}, "
+                    f"Val total_rewards_mean {val_metrics['total_rewards_mean']:.3f}, "
+                    f"Val total_rewards_std {val_metrics['total_rewards_std']:.3f}, "
+                    f"Val grouped_rewards_mean {val_metrics['grouped_rewards_mean']:.3f}, "
                     f"Val grouped_rewards_std {val_metrics['grouped_rewards_std']:.3f}, "
-                    f"Val kl {val_metrics['kl']:.3f}, "
+                    f"Val kl {val_metrics['kl']:.3f}"
+                )
+
+                # Add reward function specific metrics
+                for i in range(len(reward_funcs)):
+                    val_metrics_str += (
+                        f", Val reward_func_{i}_mean {val_metrics[f'reward_func_{i}_mean']:.3f}, "
+                        f"Val reward_func_{i}_std {val_metrics[f'reward_func_{i}_std']:.3f}"
+                    )
+
+                print(
+                    f"Iter {it}: {val_metrics_str}, "
                     f"Val took {val_time:.3f}s",
                     flush=True,
                 )
@@ -510,14 +534,24 @@ def train_grpo(
         peak_mem = mx.metal.get_peak_memory() / 1e9
         if rank == 0:
+            train_metrics_str = (
+                f"Train loss {train_loss:.8f}, "
+                f"Total rewards mean {avg_metrics['total_rewards_mean']:.3f}, "
+                f"Total rewards std {avg_metrics['total_rewards_std']:.3f}, "
+                f"Grouped rewards mean {avg_metrics['grouped_rewards_mean']:.3f}, "
+                f"Grouped rewards std {avg_metrics['grouped_rewards_std']:.3f}, "
+                f"KL {avg_metrics['kl']:.3f}"
+            )
+
+            # Add reward function specific metrics
+            for i in range(len(reward_funcs)):
+                train_metrics_str += (
+                    f", Reward func {i} mean {avg_metrics[f'reward_func_{i}_mean']:.3f}, "
+                    f"Reward func {i} std {avg_metrics[f'reward_func_{i}_std']:.3f}"
+                )
+
             print(
-                f"Iter {it}: Train loss {train_loss:.8f}, "
-                f"Rewards {avg_metrics['rewards']:.3f}, "
-                f"Rewards_std {avg_metrics['rewards_std']:.3f}, "
-                f"Grouped Rewards {avg_metrics['grouped_rewards']:.3f}, "
-                f"Grouped Rewards {avg_metrics['grouped_rewards']:.3f}, "
-                f"Grouped Rewards_std {val_metrics['grouped_rewards_std']:.3f}, "
-                f"KL {val_metrics['kl']:.3f}, "
+                f"Iter {it}: {train_metrics_str}, "
                 f"Learning Rate {learning_rate:.3e}, "
                 f"It/sec {it_sec:.3f}, "
                 f"Tokens/sec {tokens_sec:.3f}, "
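
Reviewer note: the sketch below (not part of the patch) illustrates the data flow the new GRPODataset and iterate_grpo_batches set up: the dataset is a list of (prompt, answer) string pairs, and the iterator yields only the prompts per batch, leaving the answers for the reward functions. It assumes the patched branch is installed so that mlx_lm.tuner.grpo_trainer is importable; the toy pairs are invented for illustration, and tokenizer=None is passed only because this version of the iterator never touches the tokenizer.

    # Hypothetical usage sketch for the new batch iterator.
    from mlx_lm.tuner.grpo_trainer import iterate_grpo_batches

    pairs = [
        ("What is 2 + 2?", "4"),
        ("Name the capital of France.", "Paris"),
        ("What color is a ripe banana?", "Yellow"),
        ("How many legs does a spider have?", "8"),
    ]

    # train=False iterates once over length-sorted batches instead of
    # shuffling indefinitely.
    for prompts in iterate_grpo_batches(
        dataset=pairs,
        tokenizer=None,       # unused by the iterator in this revision
        batch_size=2,
        max_seq_length=512,
        train=False,
    ):
        # Each batch is a list of prompt strings; the matching answers stay
        # in `pairs` for the reward functions to score generations against.
        print(prompts)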