From b31d9cbb65ae4e94f2fb48435b75ece4157895a9 Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Fri, 31 Jan 2025 00:01:43 +0100
Subject: [PATCH] removing is-reference-free argument

---
 llms/mlx_lm/LORA.md              |  1 -
 llms/mlx_lm/lora.py              |  3 ---
 llms/mlx_lm/tuner/dpo_trainer.py | 28 ++++++++++------------------
 3 files changed, 10 insertions(+), 22 deletions(-)

diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md
index 6dd8197d..70c29b75 100644
--- a/llms/mlx_lm/LORA.md
+++ b/llms/mlx_lm/LORA.md
@@ -95,7 +95,6 @@ The DPO training accepts the following additional parameters:
 
 - `--beta`: Controls the strength of the DPO loss (default: 0.1)
 - `--dpo-loss-type`: Choose between "sigmoid" (default), "hinge", "ipo", or "dpop" loss functions
-- `--is-reference-free`: Enable reference-free DPO training
 - `--delta`: Margin parameter for hinge loss (default: 50.0)
 - `--reference-model-path`: Path to a reference model for DPO training
 
diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index dcf94bad..f4c17c78 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -66,7 +66,6 @@ CONFIG_DEFAULTS = {
     "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
     "beta": 0.1,
     "dpo_loss_type": "sigmoid",
-    "is_reference_free": False,
     "delta": 50.0,
     "reference_model_path": None,
     "train_bias_only": False,
@@ -176,7 +175,6 @@ def build_parser():
     )
     parser.add_argument("--beta", type=float)
     parser.add_argument("--dpo-loss-type", type=str, choices=["sigmoid", "hinge", "ipo", "dpop"])
-    parser.add_argument("--is-reference-free", action="store_true")
     parser.add_argument("--delta", type=float)
     parser.add_argument("--reference-model-path", type=str)
     parser.add_argument("--train-bias-only", action="store_true")
@@ -240,7 +238,6 @@ def train_model(
         grad_checkpoint=args.grad_checkpoint,
         beta=args.beta,
         loss_type=args.dpo_loss_type,
-        is_reference_free=args.is_reference_free,
         delta=args.delta,
         reference_model_path=args.reference_model_path
     )
diff --git a/llms/mlx_lm/tuner/dpo_trainer.py b/llms/mlx_lm/tuner/dpo_trainer.py
index ed955e01..1a8ae42f 100644
--- a/llms/mlx_lm/tuner/dpo_trainer.py
+++ b/llms/mlx_lm/tuner/dpo_trainer.py
@@ -28,12 +28,6 @@ class DPOTrainingArgs(TrainingArgs):
             "help": "DPO loss type: 'sigmoid', 'hinge', 'ipo', or 'dpop'."
         }
     )
-    is_reference_free: bool = field(
-        default=False,
-        metadata={
-            "help": "Whether to use reference-free DPO training."
-        }
-    )
     delta: float = field(
         default=50.0,
         metadata={
@@ -50,7 +44,6 @@ class DPOTrainingArgs(TrainingArgs):
 
 def dpo_loss(
     model,
-    reference_teacher_model,
     chosen: mx.array,
     rejected: mx.array,
     chosen_masks: mx.array,
@@ -58,7 +51,7 @@ def dpo_loss(
     beta: float,
     delta: float,
     loss_type: str = "sigmoid",
-    is_reference_free: bool = False
+    ref_model=None,
 ):
     def make_predictions(model, x, mask):
         inputs = x[:, :-1]
@@ -84,12 +77,12 @@ def dpo_loss(
         policy_rejected_score = policy_rejected_scores.sum(-1)
 
     # Calculate log probabilities for reference model
-    if is_reference_free:
+    if ref_model is None:
         reference_chosen_score = mx.zeros_like(policy_chosen_score)
         reference_rejected_score = mx.zeros_like(policy_rejected_score)
     else:
-        reference_chosen_scores = mx.stop_gradient(make_predictions(reference_teacher_model, chosen, chosen_masks))
-        reference_rejected_scores = mx.stop_gradient(make_predictions(reference_teacher_model, rejected, rejected_masks))
+        reference_chosen_scores = mx.stop_gradient(make_predictions(ref_model, chosen, chosen_masks))
+        reference_rejected_scores = mx.stop_gradient(make_predictions(ref_model, rejected, rejected_masks))
         if loss_type == "ipo":
             # ipo uses average log probabilities
             reference_chosen_score = reference_chosen_scores.sum(-1) / num_chosen_tokens
@@ -177,7 +170,7 @@ def iterate_dpo_batches(dataset, batch_size, max_seq_length, train=False):
 
 def evaluate_dpo(
     model,
-    reference_model,
+    ref_model,
     dataset,
     batch_size,
     num_batches,
@@ -206,7 +199,7 @@ def evaluate_dpo(
 
         loss, reward, toks, metrics = loss_fn(
             model=model,
-            reference_teacher_model=reference_model,
+            ref_model=ref_model,
             chosen=chosen,
             rejected=rejected,
             chosen_masks=chosen_masks,
@@ -240,7 +233,7 @@ def evaluate_dpo(
 
 def train_dpo(
     model,
-    reference_model,
+    ref_model,
     tokenizer,
     optimizer,
     train_dataset,
@@ -267,7 +260,7 @@ def train_dpo(
 
         (loss, reward, toks, metrics), grad = loss_value_and_grad(
             model,
-            reference_model,
+            ref_model,
             chosen,
             rejected,
             chosen_masks,
@@ -289,8 +282,7 @@ def train_dpo(
             rejected_masks=rejected_masks,
             beta=args.beta,
             delta=args.delta,
-            loss_type=loss_type,
-            is_reference_free=args.is_reference_free
+            loss_type=loss_type
         )
 
         loss_value_and_grad = nn.value_and_grad(model, loss_wrapper)
@@ -324,7 +316,7 @@ def train_dpo(
             stop = time.perf_counter()
             val_loss, val_rewards, val_ntokens, val_metrics = evaluate_dpo(
                 model=model,
-                reference_model=reference_model,
+                ref_model=ref_model,
                 dataset=val_dataset,
                 batch_size=args.batch_size,
                 num_batches=args.val_batches,
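
After this change, reference-free DPO is selected by leaving ref_model as None
when calling dpo_loss, rather than by a separate --is-reference-free flag. A
minimal sketch of the resulting calling convention follows: the dpo_loss
keyword names mirror the call sites in this diff, while the wrapper
make_dpo_loss_fn is hypothetical and the import path assumes this patch is
applied to mlx_lm.

    # Sketch only: dpo_loss keywords follow this patch; the wrapper is hypothetical.
    from mlx_lm.tuner.dpo_trainer import dpo_loss

    def make_dpo_loss_fn(policy_model, reference_model=None,
                         beta=0.1, delta=50.0, loss_type="sigmoid"):
        """Build a DPO loss closure; reference_model=None means reference-free DPO."""

        def loss_fn(chosen, rejected, chosen_masks, rejected_masks):
            return dpo_loss(
                model=policy_model,
                chosen=chosen,
                rejected=rejected,
                chosen_masks=chosen_masks,
                rejected_masks=rejected_masks,
                beta=beta,
                delta=delta,
                loss_type=loss_type,
                ref_model=reference_model,  # None -> zeroed reference scores inside dpo_loss
            )

        return loss_fn

    # With a reference model:                       make_dpo_loss_fn(policy, reference_model=reference)
    # Reference-free (replaces --is-reference-free): make_dpo_loss_fn(policy)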