Remove the is_reference_free argument

Goekdeniz-Guelmez committed 2025-01-31 00:01:43 +01:00
parent b3d6fc38cd
commit b31d9cbb65
3 changed files with 10 additions and 22 deletions


@@ -28,12 +28,6 @@ class DPOTrainingArgs(TrainingArgs):
             "help": "DPO loss type: 'sigmoid', 'hinge', 'ipo', or 'dpop'."
         }
     )
-    is_reference_free: bool = field(
-        default=False,
-        metadata={
-            "help": "Whether to use reference-free DPO training."
-        }
-    )
     delta: float = field(
         default=50.0,
         metadata={
@@ -50,7 +44,6 @@ class DPOTrainingArgs(TrainingArgs):
 
 def dpo_loss(
     model,
-    reference_teacher_model,
     chosen: mx.array,
     rejected: mx.array,
     chosen_masks: mx.array,
@@ -58,7 +51,7 @@ def dpo_loss(
     beta: float,
     delta: float,
     loss_type: str = "sigmoid",
-    is_reference_free: bool = False
+    ref_model=None,
 ):
     def make_predictions(model, x, mask):
         inputs = x[:, :-1]
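The hunk above is cut off right after the first line of make_predictions, so the scoring logic itself is not visible here. Purely for orientation, a minimal sketch of how such a masked per-token log-probability scorer is typically written in MLX; the body and the mask alignment below are assumptions, not code from this file:

    import mlx.core as mx
    import mlx.nn as nn

    def make_predictions(model, x, mask):
        # Next-token setup: score x[:, 1:] given x[:, :-1].
        inputs = x[:, :-1]
        targets = x[:, 1:]
        logits = model(inputs)
        # Log-probability the model assigns to each observed target token.
        log_probs = nn.log_softmax(logits, axis=-1)
        per_token = mx.take_along_axis(log_probs, targets[..., None], axis=-1)
        per_token = mx.squeeze(per_token, axis=-1)
        # Assumed mask alignment: zero out padding so a later .sum(-1)
        # aggregates only real tokens.
        return per_token * mask[:, :-1]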
@@ -84,12 +77,12 @@ def dpo_loss(
     policy_rejected_score = policy_rejected_scores.sum(-1)
 
     # Calculate log probabilities for reference model
-    if is_reference_free:
+    if ref_model is None:
         reference_chosen_score = mx.zeros_like(policy_chosen_score)
         reference_rejected_score = mx.zeros_like(policy_rejected_score)
     else:
-        reference_chosen_scores = mx.stop_gradient(make_predictions(reference_teacher_model, chosen, chosen_masks))
-        reference_rejected_scores = mx.stop_gradient(make_predictions(reference_teacher_model, rejected, rejected_masks))
+        reference_chosen_scores = mx.stop_gradient(make_predictions(ref_model, chosen, chosen_masks))
+        reference_rejected_scores = mx.stop_gradient(make_predictions(ref_model, rejected, rejected_masks))
         if loss_type == "ipo":
             # ipo uses average log probabilities
             reference_chosen_score = reference_chosen_scores.sum(-1) / num_chosen_tokens
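In effect, reference-free DPO is now requested by simply not passing a reference model: with ref_model=None the reference scores are zeros, so for the standard 'sigmoid' formulation the objective reduces to -log(sigmoid(beta * (policy_chosen - policy_rejected))). A small self-contained sketch of that branch with made-up summed log-probabilities (the loss formula here is the standard one, assumed rather than copied from this file):

    import mlx.core as mx

    # Made-up summed sequence log-probabilities for a batch of two pairs.
    policy_chosen_score = mx.array([-12.3, -9.8])
    policy_rejected_score = mx.array([-14.1, -11.0])

    ref_model = None  # reference-free: the caller simply omits the reference model
    if ref_model is None:
        reference_chosen_score = mx.zeros_like(policy_chosen_score)
        reference_rejected_score = mx.zeros_like(policy_rejected_score)

    beta = 0.1
    logits = (policy_chosen_score - policy_rejected_score) - (
        reference_chosen_score - reference_rejected_score
    )
    # Standard sigmoid DPO loss; with zero reference scores this is just
    # -log(sigmoid(beta * (policy_chosen - policy_rejected))).
    losses = -mx.log(mx.sigmoid(beta * logits))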
@@ -177,7 +170,7 @@ def iterate_dpo_batches(dataset, batch_size, max_seq_length, train=False):
 
 def evaluate_dpo(
     model,
-    reference_model,
+    ref_model,
     dataset,
     batch_size,
     num_batches,
@@ -206,7 +199,7 @@ def evaluate_dpo(
 
         loss, reward, toks, metrics = loss_fn(
             model=model,
-            reference_teacher_model=reference_model,
+            ref_model=ref_model,
             chosen=chosen,
             rejected=rejected,
             chosen_masks=chosen_masks,
@@ -240,7 +233,7 @@ def evaluate_dpo(
 
 def train_dpo(
     model,
-    reference_model,
+    ref_model,
     tokenizer,
     optimizer,
     train_dataset,
@@ -267,7 +260,7 @@ def train_dpo(
 
         (loss, reward, toks, metrics), grad = loss_value_and_grad(
             model,
-            reference_model,
+            ref_model,
             chosen,
             rejected,
             chosen_masks,
@@ -289,8 +282,7 @@ def train_dpo(
             rejected_masks=rejected_masks,
             beta=args.beta,
             delta=args.delta,
-            loss_type=loss_type,
-            is_reference_free=args.is_reference_free
+            loss_type=loss_type
         )
 
     loss_value_and_grad = nn.value_and_grad(model, loss_wrapper)
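For context on the loss_wrapper / nn.value_and_grad pairing above: mlx.nn.value_and_grad differentiates the wrapped function only with respect to the model's trainable parameters, so everything else the closure captures (the frozen ref_model, the batch arrays, args.beta, and so on) is treated as a constant. A minimal, self-contained illustration of that pattern; the tiny MSE model below is only a stand-in, not part of this commit:

    import mlx.core as mx
    import mlx.nn as nn

    model = nn.Linear(4, 1)
    x = mx.random.normal((8, 4))
    y = mx.random.normal((8, 1))

    def loss_wrapper(model):
        # x and y are captured by the closure; gradients flow only into `model`.
        return nn.losses.mse_loss(model(x), y, reduction="mean")

    loss_value_and_grad = nn.value_and_grad(model, loss_wrapper)
    loss, grads = loss_value_and_grad(model)  # grads mirrors model.trainable_parameters()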
@@ -324,7 +316,7 @@ def train_dpo(
             stop = time.perf_counter()
             val_loss, val_rewards, val_ntokens, val_metrics = evaluate_dpo(
                 model=model,
-                reference_model=reference_model,
+                ref_model=ref_model,
                 dataset=val_dataset,
                 batch_size=args.batch_size,
                 num_batches=args.val_batches,