From a9b760911825b7043a2fe514dded0daefdacd1a0 Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Sun, 19 Jan 2025 01:09:43 +0100
Subject: [PATCH] initial commit

---
 llms/mlx_lm/lora.py               |   4 +-
 llms/mlx_lm/tuner/orpo_trainer.py | 276 ++++++++++++++++++++++++++++++
 2 files changed, 278 insertions(+), 2 deletions(-)
 create mode 100644 llms/mlx_lm/tuner/orpo_trainer.py

diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index ccfd7e01..e16e7741 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -105,8 +105,8 @@ def build_parser():
     parser.add_argument(
         "--training-mode",
         type=str,
-        choices=["normal", "dpo"],
-        help="Training mode: normal or DPO",
+        choices=["normal", "dpo", "orpo"],
+        help="Training mode: normal, DPO or ORPO",
     )
     parser.add_argument(
         "--num-layers",
diff --git a/llms/mlx_lm/tuner/orpo_trainer.py b/llms/mlx_lm/tuner/orpo_trainer.py
new file mode 100644
index 00000000..3c5b289c
--- /dev/null
+++ b/llms/mlx_lm/tuner/orpo_trainer.py
@@ -0,0 +1,276 @@
+# Copyright © 2024 Apple Inc.
+
+from dataclasses import dataclass, field
+import mlx.core as mx
+import mlx.nn as nn
+
+from .dpo_trainer import DPOTrainingArgs, iterate_dpo_batches, train_dpo, TrainingCallback
+
+
+@dataclass
+class ORPOTrainingArgs(DPOTrainingArgs):
+    """
+    Training arguments specific to ORPO, extending the DPO training arguments.
+    """
+    mu: float = field(
+        default=0.5,
+        metadata={"help": "Weight of the ORPO KL-style regularization term"}
+    )
+
+
+def orpo_loss(
+    model,
+    reference_teacher_model,
+    chosen: mx.array,
+    rejected: mx.array,
+    chosen_masks: mx.array,
+    rejected_masks: mx.array,
+    beta: float,
+    delta: float,
+    mu: float = 0.5,  # ORPO hyperparameter for balancing KL divergence
+    loss_type: str = "sigmoid",
+    is_reference_free: bool = False
+):
+    """
+    Calculate the ORPO loss for a batch of inputs.
+    ORPO extends DPO with a KL-style regularization term, computed below as a mean-squared difference between policy and reference per-token log-probabilities, which discourages the policy from drifting from the reference model and overfitting to the preference signal.
+ + Args: + model: Policy model + reference_teacher_model: Reference model + chosen: Chosen sequence tokens + rejected: Rejected sequence tokens + chosen_masks: Masks for chosen sequences + rejected_masks: Masks for rejected sequences + beta: Temperature parameter + delta: Margin for DPOP loss type + mu: ORPO hyperparameter for balancing KL divergence (default: 0.5) + loss_type: Loss type ('sigmoid', 'hinge', 'ipo', or 'dpop') + is_reference_free: Whether to use reference-free training + Returns: + Tuple of (loss, reward, num_tokens) + """ + def make_predictions(model, x, mask): + inputs = x[:, :-1] + targets = x[:, 1:] + + logits = model(inputs) + logits = logits.astype(mx.float32) + + return -nn.losses.cross_entropy(logits, targets) * mask[:, :-1] + + num_chosen_tokens = chosen_masks.sum(-1) + num_rejected_tokens = rejected_masks.sum(-1) + + # Calculate log probabilities for policy model + policy_chosen_scores = make_predictions(model, chosen, chosen_masks) + policy_rejected_scores = make_predictions(model, rejected, rejected_masks) + + # Calculate reference model scores + if not is_reference_free: + reference_chosen_scores = mx.stop_gradient(make_predictions(reference_teacher_model, chosen, chosen_masks)) + reference_rejected_scores = mx.stop_gradient(make_predictions(reference_teacher_model, rejected, rejected_masks)) + else: + reference_chosen_scores = mx.zeros_like(policy_chosen_scores) + reference_rejected_scores = mx.zeros_like(policy_rejected_scores) + + # Compute average log probabilities if using IPO loss + if loss_type == "ipo": + policy_chosen_score = policy_chosen_scores.sum(-1) / num_chosen_tokens + policy_rejected_score = policy_rejected_scores.sum(-1) / num_rejected_tokens + reference_chosen_score = reference_chosen_scores.sum(-1) / num_chosen_tokens + reference_rejected_score = reference_rejected_scores.sum(-1) / num_rejected_tokens + else: + policy_chosen_score = policy_chosen_scores.sum(-1) + policy_rejected_score = policy_rejected_scores.sum(-1) + reference_chosen_score = reference_chosen_scores.sum(-1) + reference_rejected_score = reference_rejected_scores.sum(-1) + + # Calculate preference logits + logits = (policy_chosen_score - policy_rejected_score) - (reference_chosen_score - reference_rejected_score) + + # Calculate preference loss based on loss type + if loss_type == "sigmoid": + preference_loss = -nn.log_sigmoid(beta * logits) + elif loss_type == "hinge": + preference_loss = nn.relu(1 - beta * logits) + elif loss_type == "ipo": + preference_loss = (logits - 1 / (2 * beta)) ** 2 + elif loss_type == "dpop": + penalty = mx.maximum(mx.zeros_like(policy_chosen_score), reference_chosen_score - policy_chosen_score) + preference_loss = -(nn.log_sigmoid(beta * logits) - delta * penalty) + else: + raise ValueError(f"Unknown loss type: {loss_type}") + + # Calculate KL divergence term for ORPO + kl_div_chosen = mx.mean((policy_chosen_scores - reference_chosen_scores) ** 2) + kl_div_rejected = mx.mean((policy_rejected_scores - reference_rejected_scores) ** 2) + kl_regularization = mu * (kl_div_chosen + kl_div_rejected) + + # Combine preference loss and KL regularization + loss = mx.mean(preference_loss) + kl_regularization + + num_tokens = (num_chosen_tokens + num_rejected_tokens).sum() + + # Calculate rewards for monitoring + chosen_reward = beta * mx.mean(policy_chosen_score - reference_chosen_score) + rejected_reward = beta * mx.mean(policy_rejected_score - reference_rejected_score) + reward = mx.stack([chosen_reward, rejected_reward]) + + return loss, reward, 
num_tokens + + +def evaluate_orpo( + model, + reference_model, + dataset, + tokenizer, + batch_size, + num_batches, + beta: float, + delta: float, + mu: float = 0.5, + max_seq_length=2048, + loss_type="sigmoid", + is_reference_free=False, +): + """ + Evaluate model using ORPO metrics. + + Args: + model: Policy model to evaluate + reference_model: Reference model for comparison + dataset: Evaluation dataset + tokenizer: Tokenizer for processing text + batch_size: Batch size for evaluation + num_batches: Number of batches to evaluate (-1 for full dataset) + beta: Temperature parameter + delta: Margin for DPOP loss + mu: ORPO KL divergence weight + max_seq_length: Maximum sequence length + loss_type: Type of loss function + is_reference_free: Whether to use reference-free evaluation + + Returns: + Tuple of (loss, rewards, kl_metrics), where: + - loss is the total ORPO loss + - rewards is [chosen_reward, rejected_reward] + - kl_metrics is [chosen_kl, rejected_kl] + """ + all_losses = 0 + all_rewards = mx.zeros((2,)) # [chosen_reward, rejected_reward] + all_kl_divs = mx.zeros((2,)) # [chosen_kl, rejected_kl] + ntokens = 0 + + def compute_kl_divergence(policy_scores, reference_scores, masks): + """Helper function to compute KL divergence metrics.""" + # Using MSE as a proxy for KL divergence as in the loss function + valid_tokens = masks.sum() + kl_div = ((policy_scores - reference_scores) ** 2 * masks).sum() / valid_tokens + return kl_div + + index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1) + + for _, batch in zip( + index_iterator, + iterate_dpo_batches( # Reusing DPO batch iterator + dataset=dataset, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_length=max_seq_length, + ), + ): + chosen, rejected, chosen_masks, rejected_masks = batch + + # Get model predictions + def make_predictions(model, x, mask): + inputs = x[:, :-1] + targets = x[:, 1:] + logits = model(inputs) + logits = logits.astype(mx.float32) + return -nn.losses.cross_entropy(logits, targets) * mask[:, :-1] + + # Get scores for both models + policy_chosen_scores = make_predictions(model, chosen, chosen_masks) + policy_rejected_scores = make_predictions(model, rejected, rejected_masks) + + if not is_reference_free: + reference_chosen_scores = mx.stop_gradient( + make_predictions(reference_model, chosen, chosen_masks) + ) + reference_rejected_scores = mx.stop_gradient( + make_predictions(reference_model, rejected, rejected_masks) + ) + else: + reference_chosen_scores = mx.zeros_like(policy_chosen_scores) + reference_rejected_scores = mx.zeros_like(policy_rejected_scores) + + # Compute KL divergences + chosen_kl = compute_kl_divergence( + policy_chosen_scores, reference_chosen_scores, chosen_masks[:, :-1] + ) + rejected_kl = compute_kl_divergence( + policy_rejected_scores, reference_rejected_scores, rejected_masks[:, :-1] + ) + all_kl_divs += mx.stack([chosen_kl, rejected_kl]) + + # Compute ORPO loss and rewards + loss, reward, toks = orpo_loss( + model=model, + reference_teacher_model=reference_model, + chosen=chosen, + rejected=rejected, + chosen_masks=chosen_masks, + rejected_masks=rejected_masks, + beta=beta, + delta=delta, + mu=mu, + loss_type=loss_type, + is_reference_free=is_reference_free, + ) + + all_losses += loss * toks + all_rewards += reward + ntokens += toks + mx.eval(all_losses, all_rewards, all_kl_divs, ntokens) + + # Aggregate metrics across distributed workers if necessary + all_losses = mx.distributed.all_sum(all_losses) + all_rewards = mx.distributed.all_sum(all_rewards) + 
all_kl_divs = mx.distributed.all_sum(all_kl_divs)
+    ntokens = mx.distributed.all_sum(ntokens)
+
+    # Normalize: per-token loss; rewards and KL metrics are averaged across workers
+    avg_loss = (all_losses / ntokens).item()
+    avg_rewards = [r / mx.distributed.init().size() for r in all_rewards.tolist()]
+    avg_kl_divs = [kl / mx.distributed.init().size() for kl in all_kl_divs.tolist()]
+
+    return avg_loss, avg_rewards, avg_kl_divs
+
+
+def train_orpo(
+    model,
+    reference_model,
+    tokenizer,
+    optimizer,
+    train_dataset,
+    val_dataset,
+    args: ORPOTrainingArgs = ORPOTrainingArgs(),
+    training_callback: TrainingCallback = None,
+):
+    """
+    Train a model using the ORPO objective defined in this module (a DPO-style
+    preference loss plus a mu-weighted KL-style regularizer), reusing the DPO training loop.
+    """
+    return train_dpo(
+        model=model,
+        reference_model=reference_model,
+        tokenizer=tokenizer,
+        optimizer=optimizer,
+        train_dataset=train_dataset,
+        val_dataset=val_dataset,
+        args=args,
+        loss_fn=orpo_loss,
+        training_callback=training_callback,
+        loss_type=args.loss_type,
+    )
\ No newline at end of file
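
Both orpo_loss and evaluate_orpo score sequences through the nested make_predictions helper: a masked, negated token-level cross-entropy over shifted inputs and targets. The sketch below reproduces that pattern on a hypothetical TinyLM stand-in (the model, token ids, and mask are invented for illustration and are not part of the patch) so the shapes and masking are easy to verify.

import mlx.core as mx
import mlx.nn as nn

class TinyLM(nn.Module):
    """Hypothetical stand-in for the policy/reference models: embedding + linear head."""
    def __init__(self, vocab_size: int = 10, dims: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dims)
        self.head = nn.Linear(dims, vocab_size)

    def __call__(self, x):
        return self.head(self.embed(x))

model = TinyLM()
x = mx.array([[1, 4, 2, 7, 0, 0]])                       # one padded toy sequence
mask = mx.array([[1, 1, 1, 1, 0, 0]], dtype=mx.float32)  # 1 = real token, 0 = padding

inputs, targets = x[:, :-1], x[:, 1:]
logits = model(inputs).astype(mx.float32)

# nn.losses.cross_entropy defaults to reduction="none", so negating it yields
# per-token log-probabilities; padded positions are zeroed by the mask.
per_token_logps = -nn.losses.cross_entropy(logits, targets) * mask[:, :-1]

sequence_score = per_token_logps.sum(-1)           # summed score, used by the default loss types
length_normalized = sequence_score / mask.sum(-1)  # length-normalized score, as in the "ipo" branch
print(per_token_logps.shape, sequence_score, length_normalized)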
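
For readers skimming the patch, the core of orpo_loss is a DPO-style preference loss plus a mu-weighted penalty on the gap between policy and reference scores. The snippet below is a minimal, self-contained sketch of that combination using made-up per-sequence numbers (the real code applies the penalty per token and averages over a batch); only mlx operations already used by the patch appear here.

import mlx.core as mx
import mlx.nn as nn

beta, mu = 0.1, 0.5  # temperature and regularizer weight, matching the defaults above

# Toy per-sequence log-probabilities (all values are made up for illustration).
policy_chosen, policy_rejected = mx.array([-12.0]), mx.array([-15.0])
ref_chosen, ref_rejected = mx.array([-13.0]), mx.array([-14.0])

# DPO-style preference logits and the default "sigmoid" preference loss.
logits = (policy_chosen - policy_rejected) - (ref_chosen - ref_rejected)
preference_loss = -nn.log_sigmoid(beta * logits)

# The patch's "KL" regularizer: a mean-squared difference between policy and
# reference scores (a proxy rather than a true KL divergence), weighted by mu.
kl_regularization = mu * (mx.mean((policy_chosen - ref_chosen) ** 2)
                          + mx.mean((policy_rejected - ref_rejected) ** 2))

loss = mx.mean(preference_loss) + kl_regularization
print(loss.item())  # a larger mu penalizes drift away from the reference model more strongly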