Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-09-19 11:28:07 +08:00)
removing is-reference-free argument
@@ -28,12 +28,6 @@ class DPOTrainingArgs(TrainingArgs):
             "help": "DPO loss type: 'sigmoid', 'hinge', 'ipo', or 'dpop'."
         }
     )
-    is_reference_free: bool = field(
-        default=False,
-        metadata={
-            "help": "Whether to use reference-free DPO training."
-        }
-    )
     delta: float = field(
         default=50.0,
         metadata={
@@ -50,7 +44,6 @@ class DPOTrainingArgs(TrainingArgs):
 
 def dpo_loss(
     model,
-    reference_teacher_model,
     chosen: mx.array,
     rejected: mx.array,
     chosen_masks: mx.array,
@@ -58,7 +51,7 @@ def dpo_loss(
     beta: float,
     delta: float,
     loss_type: str = "sigmoid",
-    is_reference_free: bool = False
+    ref_model=None,
 ):
     def make_predictions(model, x, mask):
         inputs = x[:, :-1]
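
The make_predictions helper that appears as context here is not shown in full by this diff. As a hedged sketch of the usual pattern only (the helper name, shapes, and toy values below are assumptions, not this file's code): after the model is run on the shifted inputs (x[:, :-1]), the remaining work is to take the log-softmax of the logits, gather the log-probability of each target token, and mask out padding; the *_scores.sum(-1) lines in the next hunk then reduce those per-token scores to one score per sequence.

    import mlx.core as mx

    def masked_token_log_probs(logits, targets, mask):
        # log-softmax over the vocabulary dimension
        log_probs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
        # gather the log-probability of each target token
        per_token = mx.take_along_axis(log_probs, mx.expand_dims(targets, -1), axis=-1)
        per_token = mx.squeeze(per_token, -1)
        # zero out padded positions so they do not contribute to the sum
        return per_token * mask

    # toy shapes: batch of 1, three next-token predictions, vocabulary of 5
    logits = mx.random.normal((1, 3, 5))
    targets = mx.array([[1, 4, 2]])
    mask = mx.array([[1.0, 1.0, 0.0]])
    print(masked_token_log_probs(logits, targets, mask).sum(-1))
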
@@ -84,12 +77,12 @@ def dpo_loss(
     policy_rejected_score = policy_rejected_scores.sum(-1)
 
     # Calculate log probabilities for reference model
-    if is_reference_free:
+    if ref_model is None:
         reference_chosen_score = mx.zeros_like(policy_chosen_score)
         reference_rejected_score = mx.zeros_like(policy_rejected_score)
     else:
-        reference_chosen_scores = mx.stop_gradient(make_predictions(reference_teacher_model, chosen, chosen_masks))
-        reference_rejected_scores = mx.stop_gradient(make_predictions(reference_teacher_model, rejected, rejected_masks))
+        reference_chosen_scores = mx.stop_gradient(make_predictions(ref_model, chosen, chosen_masks))
+        reference_rejected_scores = mx.stop_gradient(make_predictions(ref_model, rejected, rejected_masks))
         if loss_type == "ipo":
             # ipo uses average log probabilities
             reference_chosen_score = reference_chosen_scores.sum(-1) / num_chosen_tokens
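
With the flag gone, passing ref_model=None zeroes both reference scores, so the DPO objective reduces to the policy's own preference margin, which is exactly what is_reference_free=True used to select. A minimal sketch of the sigmoid variant (toy numbers and a hypothetical helper name, not the repository's code) makes the equivalence concrete:

    import mlx.core as mx

    def sigmoid_dpo_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
        # DPO logits: policy log-ratio minus reference log-ratio
        logits = (policy_chosen - policy_rejected) - (ref_chosen - ref_rejected)
        # sigmoid DPO loss: -log sigmoid(beta * logits)
        return -mx.log(mx.sigmoid(beta * logits))

    policy_chosen = mx.array([-4.0])
    policy_rejected = mx.array([-6.0])
    zeros = mx.zeros_like(policy_chosen)

    # with zero reference scores the loss depends only on the policy margin,
    # i.e. reference-free DPO
    print(sigmoid_dpo_loss(policy_chosen, policy_rejected, zeros, zeros))

When a reference model is supplied, the same formula subtracts the reference margin, which corresponds to the else: branch in the hunk above.
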
@@ -177,7 +170,7 @@ def iterate_dpo_batches(dataset, batch_size, max_seq_length, train=False):
 
 def evaluate_dpo(
     model,
-    reference_model,
+    ref_model,
     dataset,
     batch_size,
     num_batches,
@@ -206,7 +199,7 @@ def evaluate_dpo(
 
         loss, reward, toks, metrics = loss_fn(
             model=model,
-            reference_teacher_model=reference_model,
+            ref_model=ref_model,
             chosen=chosen,
             rejected=rejected,
             chosen_masks=chosen_masks,
@@ -240,7 +233,7 @@ def evaluate_dpo(
 
 def train_dpo(
     model,
-    reference_model,
+    ref_model,
     tokenizer,
     optimizer,
     train_dataset,
@@ -267,7 +260,7 @@ def train_dpo(
 
         (loss, reward, toks, metrics), grad = loss_value_and_grad(
             model,
-            reference_model,
+            ref_model,
             chosen,
             rejected,
             chosen_masks,
@@ -289,8 +282,7 @@ def train_dpo(
             rejected_masks=rejected_masks,
             beta=args.beta,
             delta=args.delta,
-            loss_type=loss_type,
-            is_reference_free=args.is_reference_free
+            loss_type=loss_type
         )
 
     loss_value_and_grad = nn.value_and_grad(model, loss_wrapper)
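
The nn.value_and_grad(model, loss_wrapper) context line above is the standard MLX pattern for differentiating a loss with respect to a module's trainable parameters. A self-contained toy example of that pattern (a small linear model and a squared-error loss, nothing from this file):

    import mlx.core as mx
    import mlx.nn as nn

    model = nn.Linear(4, 1)

    def loss_wrapper(model, x, y):
        # extra arguments are passed straight through to the wrapped function
        pred = mx.squeeze(model(x), -1)
        return mx.mean((pred - y) ** 2)

    # returns a function with the same signature that also yields gradients
    # with respect to the model's trainable parameters
    loss_value_and_grad = nn.value_and_grad(model, loss_wrapper)

    x = mx.random.normal((8, 4))
    y = mx.random.normal((8,))
    loss, grads = loss_value_and_grad(model, x, y)
    print(loss.item())
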
@@ -324,7 +316,7 @@ def train_dpo(
             stop = time.perf_counter()
             val_loss, val_rewards, val_ntokens, val_metrics = evaluate_dpo(
                 model=model,
-                reference_model=reference_model,
+                reference_model=ref_model,
                 dataset=val_dataset,
                 batch_size=args.batch_size,
                 num_batches=args.val_batches,