diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index f4c17c78..fdc400c6 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -67,8 +67,7 @@ CONFIG_DEFAULTS = {
     "beta": 0.1,
     "dpo_loss_type": "sigmoid",
     "delta": 50.0,
-    "reference_model_path": None,
-    "train_bias_only": False,
+    "reference_model_path": None
 }


@@ -173,12 +172,35 @@ def build_parser():
         help="Use gradient checkpointing to reduce memory use.",
         default=None,
     )
-    parser.add_argument("--beta", type=float)
-    parser.add_argument("--dpo-loss-type", type=str, choices=["sigmoid", "hinge", "ipo", "dpop"])
-    parser.add_argument("--delta", type=float)
-    parser.add_argument("--reference-model-path", type=str)
-    parser.add_argument("--train-bias-only", action="store_true")
     parser.add_argument("--seed", type=int, help="The PRNG seed")
+
+    # DPO args
+    parser.add_argument(
+        "--beta",
+        type=float,
+        help="Temperature parameter for DPO training.",
+        default=0.1
+    )
+    parser.add_argument(
+        "--dpo-loss-type",
+        type=str,
+        help="DPO loss type: 'sigmoid', 'hinge', 'ipo', or 'dpop'.",
+        choices=["sigmoid", "hinge", "ipo", "dpop"],
+        default="sigmoid"
+    )
+    parser.add_argument(
+        "--delta",
+        type=float,
+        help="Delta parameter for DPOP loss type.",
+        default=50.0
+    )
+    parser.add_argument(
+        "--reference-model-path",
+        type=str,
+        help="Path to reference model weights. If None, uses the same model.",
+        default=None
+    )
+
     return parser
diff --git a/llms/mlx_lm/tuner/dpo_trainer.py b/llms/mlx_lm/tuner/dpo_trainer.py
index 9e790273..8979ec0d 100644
--- a/llms/mlx_lm/tuner/dpo_trainer.py
+++ b/llms/mlx_lm/tuner/dpo_trainer.py
@@ -12,7 +12,6 @@
 import mlx.nn as nn
 import numpy as np
 from mlx.nn.utils import average_gradients
 from mlx.utils import tree_flatten
-from ..generate import generate

 from .trainer import TrainingCallback, grad_checkpoint, TrainingArgs
@@ -100,7 +99,6 @@ def dpo_loss(
     elif loss_type == "ipo":
         losses = (logits - 1 / (2 * beta)) ** 2
     elif loss_type == "dpop":
-        delta = 50
         penalty = mx.maximum(mx.zeros_like(policy_chosen_score), reference_chosen_score - policy_chosen_score)
         losses = -(nn.log_sigmoid(beta * logits) - delta * penalty)
     else:
@@ -178,7 +176,7 @@ def evaluate_dpo(
     delta: float,
     max_seq_length,
     loss_type,
-    loss_fn: callable = dpo_loss
+    loss: callable = dpo_loss
 ):
     all_losses = 0
     all_rewards = mx.zeros((2,))
@@ -197,7 +195,7 @@ def evaluate_dpo(
     ):
         chosen, rejected, chosen_masks, rejected_masks = batch

-        loss, reward, toks, metrics = loss_fn(
+        loss, reward, toks, metrics = loss(
             model=model,
             ref_model=ref_model,
             chosen=chosen,
@@ -239,7 +237,7 @@ def train_dpo(
     train_dataset,
     val_dataset,
     args: DPOTrainingArgs = DPOTrainingArgs(),
-    loss_fn: callable = dpo_loss,
+    loss: callable = dpo_loss,
     training_callback: TrainingCallback = None,
     loss_type="sigmoid",
 ):
@@ -258,7 +256,7 @@ def train_dpo(
     def step(batch):
         chosen, rejected, chosen_masks, rejected_masks = batch

-        (loss, reward, toks, metrics), grad = loss_value_and_grad(
+        (lvalue, reward, toks, metrics), grad = loss_value_and_grad(
             model,
             ref_model,
             chosen,
@@ -270,10 +268,10 @@ def train_dpo(
         grad = average_gradients(grad)
         optimizer.update(model, grad)

-        return loss, reward, toks, metrics
+        return lvalue, reward, toks, metrics

     def loss_wrapper(model, ref_model, chosen, rejected, chosen_masks, rejected_masks):
-        return loss_fn(
+        return loss(
             model=model,
             reference_teacher_model=ref_model,
             chosen=chosen,
@@ -311,7 +309,6 @@ def train_dpo(
             train=True,
         ),
     ):
-        # Report validation loss if needed
         if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
             stop = time.perf_counter()
             val_loss, val_rewards, val_ntokens, val_metrics = evaluate_dpo(
@@ -321,7 +318,7 @@ def train_dpo(
                 batch_size=args.batch_size,
                 num_batches=args.val_batches,
                 max_seq_length=args.max_seq_length,
-                loss_fn=loss_fn,
+                loss=loss,
                 beta=args.beta,
                 delta=args.delta,
                 loss_type=loss_type,
@@ -351,13 +348,15 @@ def train_dpo(

         start = time.perf_counter()

-        loss, reward, toks, metrics = step(batch)
-        losses += loss
+        lvalue, reward, toks, metrics = step(batch)
+        losses += lvalue
         rewards += reward
         n_tokens += toks
         steps += 1
+
         for k, v in metrics.items():
             accumulated_metrics[k] += v
+
         mx.eval(state, losses, rewards, n_tokens)

         if it % args.steps_per_report == 0 or it == args.iters:
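
The main behavioral change in `dpo_loss` is that the `dpop` branch now uses the `delta` passed in from the config/CLI instead of a hard-coded `delta = 50`. The following is a minimal standalone sketch of that branch, assuming `logits` is the usual DPO log-ratio difference; the scores are made-up summed log-probabilities, not values from the trainer:

```python
import mlx.core as mx
import mlx.nn as nn

beta, delta = 0.1, 50.0  # now configurable via --beta / --delta

# Illustrative summed log-probs for two (chosen, rejected) pairs.
policy_chosen_score = mx.array([-12.0, -9.5])
policy_rejected_score = mx.array([-14.0, -13.0])
reference_chosen_score = mx.array([-11.0, -10.0])
reference_rejected_score = mx.array([-13.5, -12.5])

# Assumed standard DPO logits: policy log-ratio minus reference log-ratio.
logits = (policy_chosen_score - policy_rejected_score) - (
    reference_chosen_score - reference_rejected_score
)

# DPOP adds a hinge penalty when the policy's chosen log-prob falls below
# the reference's, scaled by delta (mirrors the patched branch above).
penalty = mx.maximum(
    mx.zeros_like(policy_chosen_score),
    reference_chosen_score - policy_chosen_score,
)
losses = -(nn.log_sigmoid(beta * logits) - delta * penalty)
print(losses)
```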
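On the `lora.py` side, the DPO arguments now carry defaults matching `CONFIG_DEFAULTS`, so a plain invocation works without specifying them. A small self-contained sketch of the new flag surface (replicated with stock `argparse` here rather than imported from `mlx_lm.lora`):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--beta", type=float, default=0.1,
                    help="Temperature parameter for DPO training.")
parser.add_argument("--dpo-loss-type", type=str, default="sigmoid",
                    choices=["sigmoid", "hinge", "ipo", "dpop"])
parser.add_argument("--delta", type=float, default=50.0,
                    help="Delta parameter for DPOP loss type.")
parser.add_argument("--reference-model-path", type=str, default=None)

# Overriding only the DPOP-related flags; everything else keeps its default.
args = parser.parse_args(["--dpo-loss-type", "dpop", "--delta", "25.0"])
print(args.beta, args.dpo_loss_type, args.delta)  # 0.1 dpop 25.0
```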