removing is-reference-free argument

Goekdeniz-Guelmez 2025-01-31 00:01:43 +01:00
parent b3d6fc38cd
commit b31d9cbb65
3 changed files with 10 additions and 22 deletions

View File

@@ -95,7 +95,6 @@ The DPO training accepts the following additional parameters:
 - `--beta`: Controls the strength of the DPO loss (default: 0.1)
 - `--dpo-loss-type`: Choose between "sigmoid" (default), "hinge", "ipo", or "dpop" loss functions
-- `--is-reference-free`: Enable reference-free DPO training
 - `--delta`: Margin parameter for hinge loss (default: 50.0)
 - `--reference-model-path`: Path to a reference model for DPO training
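
For background (this is the standard DPO objective, not code taken from this repository): with the default "sigmoid" loss, `--beta` scales the log-ratio margin between the chosen completion y_w and the rejected completion y_l, and the model at `--reference-model-path` supplies the reference terms:

$$\mathcal{L}_\text{DPO} = -\log\sigma\Big(\beta\big[\log\pi_\theta(y_w\mid x) - \log\pi_\theta(y_l\mid x) - \log\pi_\text{ref}(y_w\mid x) + \log\pi_\text{ref}(y_l\mid x)\big]\Big)$$

Setting the reference terms to zero gives the reference-free variant, which is why a separate `--is-reference-free` flag becomes redundant once the loss treats a missing reference model that way (see the `ref_model is None` branch in the trainer diff below).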

View File

@@ -66,7 +66,6 @@ CONFIG_DEFAULTS = {
     "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
     "beta": 0.1,
     "dpo_loss_type": "sigmoid",
-    "is_reference_free": False,
     "delta": 50.0,
     "reference_model_path": None,
     "train_bias_only": False,
@@ -176,7 +175,6 @@ def build_parser():
     )
     parser.add_argument("--beta", type=float)
     parser.add_argument("--dpo-loss-type", type=str, choices=["sigmoid", "hinge", "ipo", "dpop"])
-    parser.add_argument("--is-reference-free", action="store_true")
     parser.add_argument("--delta", type=float)
     parser.add_argument("--reference-model-path", type=str)
     parser.add_argument("--train-bias-only", action="store_true")
@@ -240,7 +238,6 @@ def train_model(
         grad_checkpoint=args.grad_checkpoint,
         beta=args.beta,
         loss_type=args.dpo_loss_type,
-        is_reference_free=args.is_reference_free,
         delta=args.delta,
         reference_model_path=args.reference_model_path
     )
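
Note that the `parser.add_argument` calls above deliberately carry no `default=`: in upstream mlx-lm's `lora.py`, values left unset on the command line are filled in afterwards from an optional YAML config and then from `CONFIG_DEFAULTS`. A minimal sketch of that pattern, assuming this fork keeps it (the helper name is illustrative, not from the repo):

```python
# Sketch only: `fill_defaults` is a hypothetical helper illustrating why the
# DPO flags above have no argparse defaults -- values left unset fall back to
# CONFIG_DEFAULTS (e.g. beta=0.1, dpo_loss_type="sigmoid", delta=50.0).
def fill_defaults(args, config_defaults):
    for key, value in config_defaults.items():
        if getattr(args, key, None) is None:
            setattr(args, key, value)
    return args
```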

View File

@@ -28,12 +28,6 @@ class DPOTrainingArgs(TrainingArgs):
             "help": "DPO loss type: 'sigmoid', 'hinge', 'ipo', or 'dpop'."
         }
     )
-    is_reference_free: bool = field(
-        default=False,
-        metadata={
-            "help": "Whether to use reference-free DPO training."
-        }
-    )
     delta: float = field(
         default=50.0,
         metadata={
@@ -50,7 +44,6 @@ class DPOTrainingArgs(TrainingArgs):

 def dpo_loss(
     model,
-    reference_teacher_model,
     chosen: mx.array,
     rejected: mx.array,
     chosen_masks: mx.array,
@@ -58,7 +51,7 @@ def dpo_loss(
     beta: float,
     delta: float,
     loss_type: str = "sigmoid",
-    is_reference_free: bool = False
+    ref_model=None,
 ):
     def make_predictions(model, x, mask):
         inputs = x[:, :-1]
@@ -84,12 +77,12 @@ def dpo_loss(
     policy_rejected_score = policy_rejected_scores.sum(-1)

     # Calculate log probabilities for reference model
-    if is_reference_free:
+    if ref_model is None:
         reference_chosen_score = mx.zeros_like(policy_chosen_score)
         reference_rejected_score = mx.zeros_like(policy_rejected_score)
     else:
-        reference_chosen_scores = mx.stop_gradient(make_predictions(reference_teacher_model, chosen, chosen_masks))
-        reference_rejected_scores = mx.stop_gradient(make_predictions(reference_teacher_model, rejected, rejected_masks))
+        reference_chosen_scores = mx.stop_gradient(make_predictions(ref_model, chosen, chosen_masks))
+        reference_rejected_scores = mx.stop_gradient(make_predictions(ref_model, rejected, rejected_masks))
         if loss_type == "ipo":
             # ipo uses average log probabilities
             reference_chosen_score = reference_chosen_scores.sum(-1) / num_chosen_tokens
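
Only the source of the reference scores changes; everything downstream of this branch is untouched. As a self-contained illustration of how those scores feed the default "sigmoid" loss (this is the standard DPO formulation, assumed here rather than shown in the diff; the hinge/ipo/dpop branches and the use of `delta` are omitted):

```python
import mlx.core as mx
import mlx.nn as nn

def sigmoid_dpo_loss(policy_chosen_score, policy_rejected_score,
                     reference_chosen_score, reference_rejected_score,
                     beta=0.1):
    # Policy log-ratio minus reference log-ratio. In the reference-free case
    # (ref_model is None above) the reference scores are zeros, so this
    # reduces to -log_sigmoid(beta * (policy_chosen_score - policy_rejected_score)).
    logits = (policy_chosen_score - policy_rejected_score) - (
        reference_chosen_score - reference_rejected_score
    )
    return -nn.log_sigmoid(beta * logits).mean()
```

Calling this with two `mx.zeros_like` arrays for the reference scores reproduces exactly what the `ref_model is None` branch above produces.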
@@ -177,7 +170,7 @@ def iterate_dpo_batches(dataset, batch_size, max_seq_length, train=False):

 def evaluate_dpo(
     model,
-    reference_model,
+    ref_model,
     dataset,
     batch_size,
     num_batches,
@@ -206,7 +199,7 @@ def evaluate_dpo(
         loss, reward, toks, metrics = loss_fn(
             model=model,
-            reference_teacher_model=reference_model,
+            ref_model=ref_model,
             chosen=chosen,
             rejected=rejected,
             chosen_masks=chosen_masks,
@@ -240,7 +233,7 @@ def evaluate_dpo(

 def train_dpo(
     model,
-    reference_model,
+    ref_model,
     tokenizer,
     optimizer,
     train_dataset,
@@ -267,7 +260,7 @@ def train_dpo(
         (loss, reward, toks, metrics), grad = loss_value_and_grad(
             model,
-            reference_model,
+            ref_model,
             chosen,
             rejected,
             chosen_masks,
@@ -289,8 +282,7 @@ def train_dpo(
             rejected_masks=rejected_masks,
             beta=args.beta,
             delta=args.delta,
-            loss_type=loss_type,
-            is_reference_free=args.is_reference_free
+            loss_type=loss_type
         )

     loss_value_and_grad = nn.value_and_grad(model, loss_wrapper)
@@ -324,7 +316,7 @@ def train_dpo(
             stop = time.perf_counter()
             val_loss, val_rewards, val_ntokens, val_metrics = evaluate_dpo(
                 model=model,
-                reference_model=reference_model,
+                ref_model=ref_model,
                 dataset=val_dataset,
                 batch_size=args.batch_size,
                 num_batches=args.val_batches,
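
With the flag gone, callers select the mode purely through what they bind to `ref_model`: a frozen reference model for standard DPO, or `None` for reference-free training. A hedged sketch of the caller side; `make_reference_model` is a hypothetical helper, and how the fork actually loads its reference model is not visible in this diff:

```python
from mlx_lm import load  # mlx-lm's loader; returns (model, tokenizer)

def make_reference_model(reference_model_path):
    # Hypothetical helper: with --is-reference-free removed, passing None
    # as ref_model is what selects reference-free DPO.
    if reference_model_path is None:
        return None
    ref_model, _ = load(reference_model_path)
    ref_model.freeze()  # reference scores are wrapped in stop_gradient anyway
    return ref_model
```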