From 595125ad4ebe9a0c4f08bf092f8c3a783dc464ff Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Fri, 31 Jan 2025 17:19:05 +0100
Subject: [PATCH] updates

---
 llms/mlx_lm/tuner/ppo_trainer.py | 217 ++++++++++++++++++-------------
 1 file changed, 128 insertions(+), 89 deletions(-)

diff --git a/llms/mlx_lm/tuner/ppo_trainer.py b/llms/mlx_lm/tuner/ppo_trainer.py
index 63ca58bb..40dffe63 100644
--- a/llms/mlx_lm/tuner/ppo_trainer.py
+++ b/llms/mlx_lm/tuner/ppo_trainer.py
@@ -13,67 +13,92 @@ import numpy as np
 from mlx.nn.utils import average_gradients
 from mlx.utils import tree_flatten
 
+from .trainer import TrainingArgs, TrainingCallback, grad_checkpoint
 
-def grad_checkpoint(layer):
-    """
-    Update all instances of type(layer) to use gradient checkpointing.
-    """
-    fn = type(layer).__call__
-
-    def checkpointed_fn(model, *args, **kwargs):
-        def inner_fn(params, *args, **kwargs):
-            model.update(params)
-            return fn(model, *args, **kwargs)
-
-        return mx.checkpoint(inner_fn)(model.trainable_parameters(), *args, **kwargs)
-
-    type(layer).__call__ = checkpointed_fn
+def compute_ppo_loss(
+    new_logprobs: mx.array,
+    old_logprobs: mx.array,
+    values: mx.array,
+    old_values: mx.array,
+    advantages: mx.array,
+    returns: mx.array,
+    padding_mask: mx.array,
+    padding_mask_p1: mx.array = None,
+    vf_coef: float = 0.5,
+    cliprange: float = 0.2,
+    cliprange_value: float = 0.2
+) -> tuple[mx.array, mx.array, mx.array]:
+    """Compute the PPO loss with policy and value components and padding masks."""
+    padding_mask_p1 = padding_mask_p1 if padding_mask_p1 is not None else padding_mask
+
+    # Value loss: clip predictions to a band around the old values and take
+    # the element-wise worse of the clipped and unclipped squared errors
+    vpred_clipped = mx.clip(values, old_values - cliprange_value, old_values + cliprange_value)
+    vf_losses = mx.maximum(
+        mx.square(values - returns),
+        mx.square(vpred_clipped - returns)
+    )
+    vf_loss = 0.5 * mx.mean(mx.where(~padding_mask_p1, vf_losses, 0))
+
+    # Policy loss: clipped probability-ratio objective; True in the mask marks padding
+    ratio = mx.exp(new_logprobs - old_logprobs)
+    pg_losses = mx.maximum(
+        -advantages * ratio,
+        -advantages * mx.clip(ratio, 1.0 - cliprange, 1.0 + cliprange)
+    )
+    pg_loss = mx.mean(mx.where(~padding_mask, pg_losses, 0))
+
+    total_loss = pg_loss + vf_coef * vf_loss
+    return total_loss, pg_loss, vf_loss
 
 
 @dataclass
-class TrainingArgs:
-    batch_size: int = field(default=4, metadata={"help": "Minibatch size."})
-    iters: int = field(default=100, metadata={"help": "Iterations to train for."})
-    val_batches: int = field(
-        default=25,
-        metadata={
-            "help": "Number of validation batches, -1 uses the entire validation set."
-        },
-    )
-    steps_per_report: int = field(
-        default=10,
-        metadata={"help": "Number of training steps between loss reporting."},
-    )
-    steps_per_eval: int = field(
-        default=200, metadata={"help": "Number of training steps between validations."}
-    )
-    steps_per_save: int = field(
-        default=100, metadata={"help": "Save the model every number steps"}
-    )
-    max_seq_length: int = field(
-        default=2048, metadata={"help": "Maximum sequence length."}
-    )
-    adapter_file: str = field(
-        default="adapters.safetensors",
-        metadata={"help": "Save/load path for the trained adapter weights."},
-    )
-    grad_checkpoint: bool = field(
-        default=False,
-        metadata={"help": "Use gradient checkpointing to reduce memory use."},
-    )
+class PPOTrainingArgs(TrainingArgs):
+    vf_coef: float = field(default=0.5, metadata={"help": "Value function coefficient"})
+    cliprange: float = field(default=0.2, metadata={"help": "Policy gradient clipping range"})
+    cliprange_value: float = field(default=0.2, metadata={"help": "Value function clipping range"})
 
 
-def default_loss(model, inputs, targets, lengths):
-    logits = model(inputs)
-    logits = logits.astype(mx.float32)
+def ppo_loss(
+    model,
+    inputs,
+    targets,
+    lengths,
+    old_logprobs,
+    values,
+    old_values,
+    advantages,
+    returns,
+    vf_coef=0.5,
+    cliprange=0.2,
+    cliprange_value=0.2
+):
+    # Get new logits and create the length mask
+    logits = model(inputs).astype(mx.float32)
+    length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
 
-    length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
+    # Per-token log probs: cross_entropy returns the negative log-likelihood,
+    # so negate it to recover log-probabilities
+    new_logprobs = -nn.losses.cross_entropy(logits, targets) * length_mask
+    ntoks = length_mask.sum()
 
-    ce = nn.losses.cross_entropy(logits, targets) * length_mask
-    ntoks = length_mask.sum()
-    ce = ce.sum() / ntoks
+    # Value loss with clipping
+    vpred_clipped = mx.clip(values, old_values - cliprange_value, old_values + cliprange_value)
+    vf_loss = 0.5 * mx.maximum(
+        mx.square(values - returns),
+        mx.square(vpred_clipped - returns)
+    ).mean()
 
-    return ce, ntoks
+    # Policy loss with clipping, averaged over the unpadded tokens
+    ratio = mx.exp(new_logprobs - old_logprobs)
+    pg_losses = mx.maximum(
+        -advantages * ratio,
+        -advantages * mx.clip(ratio, 1.0 - cliprange, 1.0 + cliprange)
+    )
+    pg_loss = (pg_losses * length_mask).sum() / ntoks
+
+    total_loss = pg_loss + vf_coef * vf_loss
+    return total_loss, pg_loss, vf_loss, ntoks
 
 
 def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
@@ -131,49 +156,63 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
 
 
 def evaluate(
-    model,
-    dataset,
-    tokenizer,
-    batch_size,
-    num_batches,
-    max_seq_length=2048,
-    loss: callable = default_loss,
-    iterate_batches: callable = iterate_batches,
+    model,
+    dataset,
+    tokenizer,
+    batch_size,
+    num_batches,
+    max_seq_length=2048,
+    old_logprobs=None,
+    values=None,
+    old_values=None,
+    advantages=None,
+    returns=None,
+    vf_coef=0.5,
+    cliprange=0.2,
+    cliprange_value=0.2,
+    loss: callable = ppo_loss,
+    iterate_batches: callable = iterate_batches,
 ):
-    all_losses = 0
-    ntokens = 0
+    total_loss = 0
+    total_pg_loss = 0
+    total_vf_loss = 0
+    ntokens = 0
 
-    index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
+    index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
 
-    for _, batch in zip(
-        index_iterator,
-        iterate_batches(
-            dataset=dataset,
-            tokenizer=tokenizer,
-            batch_size=batch_size,
-            max_seq_length=max_seq_length,
-        ),
-    ):
-        losses, toks = loss(model, *batch)
-        all_losses += losses * toks
-        ntokens += toks
-        mx.eval(all_losses, ntokens)
+    for _, batch in zip(
+        index_iterator,
+        iterate_batches(
+            dataset=dataset,
+            tokenizer=tokenizer,
+            batch_size=batch_size,
+            max_seq_length=max_seq_length,
+        ),
+    ):
+        losses, pg_loss, vf_loss, toks = loss(
+            model, *batch,
+            old_logprobs=old_logprobs,
+            values=values,
+            old_values=old_values,
+            advantages=advantages,
+            returns=returns,
+            vf_coef=vf_coef,
+            cliprange=cliprange,
+            cliprange_value=cliprange_value
+        )
+
+        total_loss += losses * toks
+        total_pg_loss += pg_loss * toks
+        total_vf_loss += vf_loss * toks
+        ntokens += toks
+        mx.eval(total_loss, total_pg_loss, total_vf_loss, ntokens)
 
-    all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
-    ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
+    total_loss = mx.distributed.all_sum(total_loss, stream=mx.cpu)
+    total_pg_loss = mx.distributed.all_sum(total_pg_loss, stream=mx.cpu)
+    total_vf_loss = mx.distributed.all_sum(total_vf_loss, stream=mx.cpu)
+    ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
 
-    return (all_losses / ntokens).item()
-
-
-class TrainingCallback:
-
-    def on_train_loss_report(self, train_info: dict):
-        """Called to report training loss at specified intervals."""
-        pass
-
-    def on_val_loss_report(self, val_info: dict):
-        """Called to report validation loss at specified intervals or the beginning."""
-        pass
+    return (total_loss / ntokens).item(), (total_pg_loss / ntokens).item(), (total_vf_loss / ntokens).item()
 
 
 def train(
@@ -183,7 +222,7 @@ def train(
     model,
     tokenizer,
     optimizer,
     train_dataset,
     val_dataset,
     args: TrainingArgs = TrainingArgs(),
-    loss: callable = default_loss,
+    loss: callable = ppo_loss,
     iterate_batches: callable = iterate_batches,
     training_callback: TrainingCallback = None,
 ):
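
A quick sanity check of the new loss, illustrative only and not part of the patch: the snippet below drives compute_ppo_loss with toy arrays. The 1-D token-level shapes and the convention that True in padding_mask marks padded positions are assumptions read off the masking (~padding_mask) in the diff.

    import mlx.core as mx

    from mlx_lm.tuner.ppo_trainer import compute_ppo_loss  # module path from this diff

    # One rollout of four tokens; the last position is padding.
    new_logprobs = mx.array([-1.0, -0.9, -1.4, 0.0])
    old_logprobs = mx.array([-1.2, -0.8, -1.5, 0.0])
    values = mx.array([0.5, 0.6, 0.4, 0.0])
    old_values = mx.array([0.4, 0.7, 0.5, 0.0])
    advantages = mx.array([0.3, -0.2, 0.1, 0.0])
    returns = mx.array([0.9, 0.5, 0.6, 0.0])
    padding_mask = mx.array([False, False, False, True])  # True marks padding

    total, pg, vf = compute_ppo_loss(
        new_logprobs, old_logprobs, values, old_values,
        advantages, returns, padding_mask,
    )
    print(total.item(), pg.item(), vf.item())

With cliprange=0.2 the probability ratio exp(new_logprobs - old_logprobs) is clipped to [0.8, 1.2], so a single noisy advantage estimate cannot push the policy arbitrarily far in one update.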
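
Likewise, a minimal sketch of constructing the new argument dataclass; the inherited fields (batch_size, iters, and the rest) come from TrainingArgs in trainer.py, which PPOTrainingArgs extends in this diff:

    from mlx_lm.tuner.ppo_trainer import PPOTrainingArgs

    # Inherited TrainingArgs fields and the new PPO knobs in one place.
    args = PPOTrainingArgs(
        batch_size=4,
        iters=200,
        vf_coef=0.5,          # weight of the value term in total_loss
        cliprange=0.2,        # policy ratio clipped to [0.8, 1.2]
        cliprange_value=0.2,  # value predictions clipped to old_values +/- 0.2
    )
    print(args)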