# Copyright © 2024 Apple Inc. import glob import shutil import time from dataclasses import dataclass, field from pathlib import Path from typing import Union import mlx.core as mx import mlx.nn as nn import numpy as np from mlx.nn.utils import average_gradients from mlx.utils import tree_flatten from .trainer import TrainingCallback, grad_checkpoint, TrainingArgs @dataclass class DPOTrainingArgs(TrainingArgs): beta: float = field( default=0.1, metadata={"help": "Temperature parameter for DPO training."} ) loss_type: str = field( default="sigmoid", metadata={ "help": "DPO loss type: 'sigmoid', 'hinge', 'ipo', or 'dpop'." } ) delta: float = field( default=50.0, metadata={ "help": "Delta parameter for DPOP loss type." } ) reference_model_path: str = field( default=None, metadata={ "help": "Path to reference model weights. If None, uses the same model." } ) def dpo_loss( model, chosen: mx.array, rejected: mx.array, chosen_masks: mx.array, rejected_masks: mx.array, beta: float, delta: float, loss_type: str = "sigmoid", ref_model=None, ): def make_predictions(model, x, mask): inputs = x[:, :-1] targets = x[:, 1:] logits = model(inputs) logits = logits.astype(mx.float32) return -nn.losses.cross_entropy(logits, targets) * mask[:, :-1] num_chosen_tokens = chosen_masks.sum(-1) num_rejected_tokens = rejected_masks.sum(-1) # Calculate log probabilities for policy model policy_chosen_scores = make_predictions(model, chosen, chosen_masks) policy_rejected_scores = make_predictions(model, rejected, rejected_masks) if loss_type == "ipo": # ipo uses average log probabilities policy_chosen_score = policy_chosen_scores.sum(-1) / num_chosen_tokens policy_rejected_score = policy_rejected_scores.sum(-1) / num_rejected_tokens else: policy_chosen_score = policy_chosen_scores.sum(-1) policy_rejected_score = policy_rejected_scores.sum(-1) # Calculate log probabilities for reference model if ref_model is None: reference_chosen_score = mx.zeros_like(policy_chosen_score) reference_rejected_score = mx.zeros_like(policy_rejected_score) else: reference_chosen_scores = mx.stop_gradient(make_predictions(ref_model, chosen, chosen_masks)) reference_rejected_scores = mx.stop_gradient(make_predictions(ref_model, rejected, rejected_masks)) if loss_type == "ipo": # ipo uses average log probabilities reference_chosen_score = reference_chosen_scores.sum(-1) / num_chosen_tokens reference_rejected_score = reference_rejected_scores.sum(-1) / num_rejected_tokens else: reference_chosen_score = reference_chosen_scores.sum(-1) reference_rejected_score = reference_rejected_scores.sum(-1) logits = (policy_chosen_score - policy_rejected_score) - (reference_chosen_score - reference_rejected_score) if loss_type == "sigmoid": # From the og paper losses = -nn.log_sigmoid(beta * logits) elif loss_type == "hinge": losses = nn.relu(1 - beta * logits) elif loss_type == "ipo": losses = (logits - 1 / (2 * beta)) ** 2 elif loss_type == "dpop": penalty = mx.maximum(mx.zeros_like(policy_chosen_score), reference_chosen_score - policy_chosen_score) losses = -(nn.log_sigmoid(beta * logits) - delta * penalty) else: raise ValueError(f"Unknown loss type: {loss_type}") num_tokens = (num_chosen_tokens + num_rejected_tokens).sum() chosen_reward = beta * mx.mean(policy_chosen_score - reference_chosen_score) rejected_reward = beta * mx.mean(policy_rejected_score - reference_rejected_score) reward = mx.stack([chosen_reward, rejected_reward]) metrics = { 'accuracies': mx.mean((chosen_reward > rejected_reward).astype(mx.float32)), 'margins': mx.mean(chosen_reward - rejected_reward), 'policy_rejected_logps': mx.mean(policy_rejected_score / num_rejected_tokens), 'policy_chosen_logps': mx.mean(policy_chosen_score / num_chosen_tokens), 'rejected_logits_mean': mx.mean(policy_rejected_score), 'chosen_logits_mean': mx.mean(policy_chosen_score) } return mx.mean(losses), reward, num_tokens, metrics def iterate_dpo_batches(dataset, batch_size, max_seq_length, train=False): idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]['chosen'])) step = mx.distributed.init().size() if batch_size % step != 0: raise ValueError("Batch size must be divisible by workers") batch_idx = [idx[i:i+batch_size:step] for i in range(0, len(idx)-batch_size+1, batch_size)] while True: indices = np.random.permutation(len(batch_idx)) if train else range(len(batch_idx)) for i in indices: batch = [dataset[j] for j in batch_idx[i]] # Get and process lengths chosen_lengths = [len(x['chosen']) for x in batch] rejected_lengths = [len(x['rejected']) for x in batch] max_length = min(max(max(chosen_lengths), max(rejected_lengths)), max_seq_length) # Dynamic padding based on batch content max_length_in_batch = max_length chosen_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32) rejected_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32) chosen_masks = np.zeros((batch_size // step, max_length_in_batch), np.float32) rejected_masks = np.zeros((batch_size // step, max_length_in_batch), np.float32) for j in range(batch_size // step): chosen_length = min(chosen_lengths[j], max_seq_length) rejected_length = min(rejected_lengths[j], max_seq_length) chosen_arr[j, :chosen_length] = batch[j]['chosen'][:chosen_length] rejected_arr[j, :rejected_length] = batch[j]['rejected'][:rejected_length] chosen_masks[j, :chosen_length] = 1.0 rejected_masks[j, :rejected_length] = 1.0 yield mx.array(chosen_arr), mx.array(rejected_arr), mx.array(chosen_masks), mx.array(rejected_masks) if not train: break def evaluate_dpo( model, ref_model, dataset, batch_size, num_batches, beta: float, delta: float, max_seq_length, loss_type, loss: callable = dpo_loss ): all_losses = 0 all_rewards = mx.zeros((2,)) all_metrics = None ntokens = 0 index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1) for _, batch in zip( index_iterator, iterate_dpo_batches( dataset=dataset, batch_size=batch_size, max_seq_length=max_seq_length, ), ): chosen, rejected, chosen_masks, rejected_masks = batch loss, reward, toks, metrics = loss( model=model, ref_model=ref_model, chosen=chosen, rejected=rejected, chosen_masks=chosen_masks, rejected_masks=rejected_masks, loss_type=loss_type, beta=beta, delta=delta, ) all_losses += loss * toks all_rewards += reward ntokens += toks if all_metrics is None: all_metrics = {k: v * toks for k, v in metrics.items()} else: for k, v in metrics.items(): all_metrics[k] += v * toks mx.eval(all_losses, all_rewards, ntokens) all_losses = mx.distributed.all_sum(all_losses) all_rewards = mx.distributed.all_sum(all_rewards) ntokens = mx.distributed.all_sum(ntokens) all_metrics = {k: mx.distributed.all_sum(v) for k, v in all_metrics.items()} avg_metrics = {k: (v / ntokens).item() for k, v in all_metrics.items()} avg_rewards = (all_rewards / ntokens).tolist() avg_loss = (all_losses / ntokens).item() return avg_loss, avg_rewards, ntokens, avg_metrics def train_dpo( model, ref_model, tokenizer, optimizer, train_dataset, val_dataset, args: DPOTrainingArgs = DPOTrainingArgs(), loss: callable = dpo_loss, training_callback: TrainingCallback = None, loss_type="sigmoid", ): print(f"Starting DPO training..., iters: {args.iters}") world = mx.distributed.init() world_size = world.size() rank = world.rank() if world_size > 1: print(f"Node {rank} of {world_size}") if args.grad_checkpoint: grad_checkpoint(model.layers[0]) state = [model.state, optimizer.state] def step(batch): chosen, rejected, chosen_masks, rejected_masks = batch (lvalue, reward, toks, metrics), grad = loss_value_and_grad( model, ref_model, chosen, rejected, chosen_masks, rejected_masks ) grad = average_gradients(grad) optimizer.update(model, grad) return lvalue, reward, toks, metrics def loss_wrapper(model, ref_model, chosen, rejected, chosen_masks, rejected_masks): return loss( model=model, reference_teacher_model=ref_model, chosen=chosen, rejected=rejected, chosen_masks=chosen_masks, rejected_masks=rejected_masks, beta=args.beta, delta=args.delta, loss_type=loss_type ) loss_value_and_grad = nn.value_and_grad(model, loss_wrapper) losses = 0 rewards = mx.zeros((2,)) n_tokens = 0 steps = 0 trained_tokens = 0 accumulated_metrics = { 'accuracies': 0, 'margins': 0, 'policy_rejected_logps': 0, 'policy_chosen_logps': 0, 'rejected_logits_mean': 0, 'chosen_logits_mean': 0 } start = time.perf_counter() for it, batch in zip( range(1, args.iters + 1), iterate_dpo_batches( dataset=train_dataset, batch_size=args.batch_size, max_seq_length=args.max_seq_length, train=True, ), ): if it == 1 or it % args.steps_per_eval == 0 or it == args.iters: stop = time.perf_counter() val_loss, val_rewards, val_ntokens, val_metrics = evaluate_dpo( model=model, reference_model=ref_model, dataset=val_dataset, batch_size=args.batch_size, num_batches=args.val_batches, max_seq_length=args.max_seq_length, loss=loss, beta=args.beta, delta=args.delta, loss_type=loss_type, ) val_time = time.perf_counter() - stop if rank == 0: print( f"Iter {it}: " f"Val loss {val_loss:.3f}, " f"Val chosen reward {val_rewards[0]:.3f}, " f"Val rejected reward {val_rewards[1]:.3f}, " f"Val accuracy {val_metrics['accuracies']:.3f}, " f"Val margin {val_metrics['margins']:.3f}, " f"Val took {val_time:.3f}s", flush=True, ) if training_callback is not None: training_callback.on_val_loss_report({ "iteration": it, "val_loss": val_loss, "val_chosen_reward": val_rewards[0], "val_rejected_reward": val_rewards[1], **{f"val_{k}": v for k, v in val_metrics.items()}, "val_time": val_time, }) start = time.perf_counter() lvalue, reward, toks, metrics = step(batch) losses += lvalue rewards += reward n_tokens += toks steps += 1 for k, v in metrics.items(): accumulated_metrics[k] += v mx.eval(state, losses, rewards, n_tokens) if it % args.steps_per_report == 0 or it == args.iters: stop = time.perf_counter() train_loss = mx.distributed.all_sum(losses).item() / (steps * world_size) train_rewards = mx.distributed.all_sum(rewards).tolist() train_rewards = [r / (steps * world_size) for r in train_rewards] avg_metrics = {k: v / (steps * world_size) for k, v in accumulated_metrics.items()} n_tokens = mx.distributed.all_sum(n_tokens).item() learning_rate = optimizer.learning_rate.item() it_sec = args.steps_per_report / (stop - start) tokens_sec = float(n_tokens) / (stop - start) trained_tokens += n_tokens peak_mem = mx.metal.get_peak_memory() / 1e9 if rank == 0: print( f"Iter {it}: Train loss {train_loss:.3f}, " f"Chosen reward {train_rewards[0]:.3f}, " f"Rejected reward {train_rewards[1]:.3f}, " f"Accuracy {avg_metrics['accuracies']:.3f}, " f"Margin {avg_metrics['margins']:.3f}, " f"Learning Rate {learning_rate:.3e}, " f"It/sec {it_sec:.3f}, " f"Tokens/sec {tokens_sec:.3f}, " f"Trained Tokens {trained_tokens}, " f"Peak mem {peak_mem:.3f} GB", flush=True, ) if training_callback is not None: train_info = { "iteration": it, "train_loss": train_loss, "train_chosen_reward": train_rewards[0], "train_rejected_reward": train_rewards[1], **{f"train_{k}": v for k, v in avg_metrics.items()}, "learning_rate": learning_rate, "iterations_per_second": it_sec, "tokens_per_second": tokens_sec, "trained_tokens": trained_tokens, "peak_memory": peak_mem, } training_callback.on_train_loss_report(train_info) losses = 0 rewards = mx.zeros((2,)) n_tokens = 0 steps = 0 start = time.perf_counter() # Save adapter weights if it % args.steps_per_save == 0: adapter_weights = dict(tree_flatten(model.trainable_parameters())) mx.save_safetensors(str(args.adapter_file), adapter_weights) checkpoint = ( Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors" ) mx.save_safetensors(str(checkpoint), adapter_weights) print( f"Iter {it}: Saved adapter weights to " f"{args.adapter_file} and {checkpoint}." ) # Save final weights adapter_weights = dict(tree_flatten(model.trainable_parameters())) mx.save_safetensors(str(args.adapter_file), adapter_weights) print(f"Saved final weights to {args.adapter_file}.")