From 4d0e52f7c889b5d8d4c121d4b74559ded5a3e655 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Sun, 26 Jan 2025 15:09:55 +0100 Subject: [PATCH] more metrics --- llms/mlx_lm/tuner/dpo_trainer.py | 72 ++++++++++++++++++++++---------- 1 file changed, 50 insertions(+), 22 deletions(-) diff --git a/llms/mlx_lm/tuner/dpo_trainer.py b/llms/mlx_lm/tuner/dpo_trainer.py index 8a3590fa..4ddc3d2e 100644 --- a/llms/mlx_lm/tuner/dpo_trainer.py +++ b/llms/mlx_lm/tuner/dpo_trainer.py @@ -113,14 +113,23 @@ def dpo_loss( else: raise ValueError(f"Unknown loss type: {loss_type}") - loss = mx.mean(losses) num_tokens = (num_chosen_tokens + num_rejected_tokens).sum() chosen_reward = beta * mx.mean(policy_chosen_score - reference_chosen_score) rejected_reward = beta * mx.mean(policy_rejected_score - reference_rejected_score) reward = mx.stack([chosen_reward, rejected_reward]) - return loss, reward, num_tokens + metrics = { + 'accuracies': mx.mean((chosen_reward > rejected_reward).astype(mx.float32)), + 'margins': mx.mean(chosen_reward - rejected_reward), + 'policy_rejected_logps': mx.mean(policy_rejected_score / num_rejected_tokens), + 'policy_chosen_logps': mx.mean(policy_chosen_score / num_chosen_tokens), + 'rejected_logits_mean': mx.mean(policy_rejected_score), + 'chosen_logits_mean': mx.mean(policy_chosen_score) + } + + + return mx.mean(losses), reward, num_tokens, metrics def iterate_dpo_batches(dataset, batch_size, max_seq_length, train=False): @@ -182,6 +191,7 @@ def evaluate_dpo( ): all_losses = 0 all_rewards = mx.zeros((2,)) + all_metrics = None ntokens = 0 index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1) @@ -196,7 +206,7 @@ def evaluate_dpo( ): chosen, rejected, chosen_masks, rejected_masks = batch - loss, reward, toks = loss_fn( + loss, reward, toks, metrics = loss_fn( model=model, reference_teacher_model=reference_model, chosen=chosen, @@ -211,12 +221,23 @@ def evaluate_dpo( all_rewards += reward ntokens += toks + if all_metrics is None: + all_metrics = {k: v * toks for k, v in metrics.items()} + else: + for k, v in metrics.items(): + all_metrics[k] += v * toks + + mx.eval(all_losses, all_rewards, ntokens) all_losses = mx.distributed.all_sum(all_losses) all_rewards = mx.distributed.all_sum(all_rewards) - ntokens = mx.distributed.all_sum(ntokens) + all_metrics = {k: mx.distributed.all_sum(v) for k, v in all_metrics.items()} - return (all_losses / ntokens).item(), all_rewards.tolist() + avg_metrics = {k: (v / ntokens).item() for k, v in all_metrics.items()} + avg_rewards = (all_rewards / ntokens).tolist() + avg_loss = (all_losses / ntokens).item() + + return avg_loss, avg_rewards, ntokens, avg_metrics def train_dpo( @@ -246,8 +267,7 @@ def train_dpo( def step(batch): chosen, rejected, chosen_masks, rejected_masks = batch - # Remove loss_type from the call - (loss, reward, toks), grad = loss_value_and_grad( + (loss, reward, toks, metrics), grad = loss_value_and_grad( model, reference_model, chosen, @@ -256,15 +276,11 @@ def train_dpo( rejected_masks ) - # All reduce the gradients if running in distributed mode grad = average_gradients(grad) - - # Model update optimizer.update(model, grad) - return loss, reward, toks + return loss, reward, toks, metrics - # Create a wrapper function that includes all required arguments def loss_wrapper(model, ref_model, chosen, rejected, chosen_masks, rejected_masks): return loss_fn( model=model, @@ -279,7 +295,6 @@ def train_dpo( is_reference_free=args.is_reference_free ) - # Create value_and_grad with the wrapper loss_value_and_grad = nn.value_and_grad(model, loss_wrapper) losses = 0 @@ -287,8 +302,15 @@ def train_dpo( n_tokens = 0 steps = 0 trained_tokens = 0 + accumulated_metrics = { + 'accuracies': 0, + 'margins': 0, + 'policy_rejected_logps': 0, + 'policy_chosen_logps': 0, + 'rejected_logits_mean': 0, + 'chosen_logits_mean': 0 + } - # Main training loop start = time.perf_counter() for it, batch in zip( range(1, args.iters + 1), @@ -302,7 +324,7 @@ def train_dpo( # Report validation loss if needed if it == 1 or it % args.steps_per_eval == 0 or it == args.iters: stop = time.perf_counter() - val_loss, val_rewards = evaluate_dpo( + val_loss, val_rewards, val_ntokens, val_metrics = evaluate_dpo( model=model, reference_model=reference_model, dataset=val_dataset, @@ -322,37 +344,40 @@ def train_dpo( f"Val loss {val_loss:.8f}, " f"Val chosen reward {val_rewards[0]:.3f}, " f"Val rejected reward {val_rewards[1]:.3f}, " + f"Val accuracy {val_metrics['accuracies']:.3f}, " + f"Val margin {val_metrics['margins']:.3f}, " f"Val took {val_time:.3f}s", flush=True, ) if training_callback is not None: - val_info = { + training_callback.on_val_loss_report({ "iteration": it, "val_loss": val_loss, "val_chosen_reward": val_rewards[0], "val_rejected_reward": val_rewards[1], + **{f"val_{k}": v for k, v in val_metrics.items()}, "val_time": val_time, - } - training_callback.on_val_loss_report(val_info) + }) start = time.perf_counter() - loss, reward, toks = step(batch) + loss, reward, toks, metrics = step(batch) losses += loss rewards += reward n_tokens += toks steps += 1 + for k, v in metrics.items(): + accumulated_metrics[k] += v mx.eval(state, losses, rewards, n_tokens) - # Report training loss if needed if it % args.steps_per_report == 0 or it == args.iters: stop = time.perf_counter() - train_loss = mx.distributed.all_sum(losses).item() - train_loss /= steps * world_size + train_loss = mx.distributed.all_sum(losses).item() / (steps * world_size) train_rewards = mx.distributed.all_sum(rewards).tolist() train_rewards = [r / (steps * world_size) for r in train_rewards] + avg_metrics = {k: v / (steps * world_size) for k, v in accumulated_metrics.items()} n_tokens = mx.distributed.all_sum(n_tokens).item() learning_rate = optimizer.learning_rate.item() it_sec = args.steps_per_report / (stop - start) @@ -365,6 +390,8 @@ def train_dpo( f"Iter {it}: Train loss {train_loss:.8f}, " f"Chosen reward {train_rewards[0]:.3f}, " f"Rejected reward {train_rewards[1]:.3f}, " + f"Accuracy {avg_metrics['accuracies']:.3f}, " + f"Margin {avg_metrics['margins']:.3f}, " f"Learning Rate {learning_rate:.3e}, " f"It/sec {it_sec:.3f}, " f"Tokens/sec {tokens_sec:.3f}, " @@ -379,6 +406,7 @@ def train_dpo( "train_loss": train_loss, "train_chosen_reward": train_rewards[0], "train_rejected_reward": train_rewards[1], + **{f"train_{k}": v for k, v in avg_metrics.items()}, "learning_rate": learning_rate, "iterations_per_second": it_sec, "tokens_per_second": tokens_sec,