Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-08-11 03:36:42 +08:00)
removing is-reference-free argument
This commit is contained in:
parent b3d6fc38cd
commit b31d9cbb65
@@ -95,7 +95,6 @@ The DPO training accepts the following additional parameters:
 
 - `--beta`: Controls the strength of the DPO loss (default: 0.1)
 - `--dpo-loss-type`: Choose between "sigmoid" (default), "hinge", "ipo", or "dpop" loss functions
-- `--is-reference-free`: Enable reference-free DPO training
 - `--delta`: Margin parameter for hinge loss (default: 50.0)
 - `--reference-model-path`: Path to a reference model for DPO training
 
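For readers unfamiliar with these knobs: `--beta` scales the implicit reward margin and `--dpo-loss-type` selects how that margin is turned into a loss. Below is a minimal sketch of the sigmoid variant only, written against MLX; it is a textbook formulation with made-up names and toy numbers, not the code in this repository.

import mlx.core as mx

def sigmoid_dpo_loss_sketch(policy_chosen, policy_rejected,
                            reference_chosen, reference_rejected, beta=0.1):
    # All inputs are 1-D arrays of sequence-level log-probability scores.
    logits = (policy_chosen - reference_chosen) - (policy_rejected - reference_rejected)
    # -log(sigmoid(beta * logits)), written with logaddexp for numerical stability
    losses = mx.logaddexp(0.0, -beta * logits)
    return losses.mean()

# Toy scores: the policy prefers the chosen completion more strongly than the
# reference model does, so the resulting loss is small.
print(sigmoid_dpo_loss_sketch(mx.array([-10.0]), mx.array([-14.0]),
                              mx.array([-11.0]), mx.array([-12.0])))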
@@ -66,7 +66,6 @@ CONFIG_DEFAULTS = {
     "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
     "beta": 0.1,
     "dpo_loss_type": "sigmoid",
-    "is_reference_free": False,
     "delta": 50.0,
     "reference_model_path": None,
     "train_bias_only": False,
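A minimal sketch of the DPO-related configuration after this change; the keys mirror CONFIG_DEFAULTS above, and the reading that leaving `reference_model_path` as None now implies reference-free training is an assumption based on the `ref_model is None` check later in this commit.

# Hypothetical user config: only the DPO-related keys are shown.
dpo_config = {
    "beta": 0.1,                   # strength of the DPO loss
    "dpo_loss_type": "sigmoid",    # "sigmoid", "hinge", "ipo", or "dpop"
    "delta": 50.0,                 # margin used by the hinge loss
    "reference_model_path": None,  # None -> no reference model (reference-free DPO, assumed)
}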
@@ -176,7 +175,6 @@ def build_parser():
     )
     parser.add_argument("--beta", type=float)
     parser.add_argument("--dpo-loss-type", type=str, choices=["sigmoid", "hinge", "ipo", "dpop"])
-    parser.add_argument("--is-reference-free", action="store_true")
     parser.add_argument("--delta", type=float)
     parser.add_argument("--reference-model-path", type=str)
     parser.add_argument("--train-bias-only", action="store_true")
@@ -240,7 +238,6 @@ def train_model(
         grad_checkpoint=args.grad_checkpoint,
         beta=args.beta,
         loss_type=args.dpo_loss_type,
-        is_reference_free=args.is_reference_free,
         delta=args.delta,
         reference_model_path=args.reference_model_path
     )
@@ -28,12 +28,6 @@ class DPOTrainingArgs(TrainingArgs):
             "help": "DPO loss type: 'sigmoid', 'hinge', 'ipo', or 'dpop'."
         }
     )
-    is_reference_free: bool = field(
-        default=False,
-        metadata={
-            "help": "Whether to use reference-free DPO training."
-        }
-    )
     delta: float = field(
         default=50.0,
         metadata={
@@ -50,7 +44,6 @@ class DPOTrainingArgs(TrainingArgs):
 
 def dpo_loss(
     model,
-    reference_teacher_model,
     chosen: mx.array,
     rejected: mx.array,
     chosen_masks: mx.array,
@@ -58,7 +51,7 @@ def dpo_loss(
     beta: float,
     delta: float,
     loss_type: str = "sigmoid",
-    is_reference_free: bool = False
+    ref_model=None,
 ):
     def make_predictions(model, x, mask):
         inputs = x[:, :-1]
@@ -84,12 +77,12 @@ def dpo_loss(
     policy_rejected_score = policy_rejected_scores.sum(-1)
 
     # Calculate log probabilities for reference model
-    if is_reference_free:
+    if ref_model is None:
         reference_chosen_score = mx.zeros_like(policy_chosen_score)
         reference_rejected_score = mx.zeros_like(policy_rejected_score)
     else:
-        reference_chosen_scores = mx.stop_gradient(make_predictions(reference_teacher_model, chosen, chosen_masks))
-        reference_rejected_scores = mx.stop_gradient(make_predictions(reference_teacher_model, rejected, rejected_masks))
+        reference_chosen_scores = mx.stop_gradient(make_predictions(ref_model, chosen, chosen_masks))
+        reference_rejected_scores = mx.stop_gradient(make_predictions(ref_model, rejected, rejected_masks))
     if loss_type == "ipo":
         # ipo uses average log probabilities
         reference_chosen_score = reference_chosen_scores.sum(-1) / num_chosen_tokens
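The `ref_model is None` branch is what makes the standalone flag redundant: with no reference model, the reference scores are zero and the preference margin reduces to the plain difference of policy scores, i.e. reference-free DPO. A self-contained toy check (names and numbers are illustrative, not the trainer's real values):

import mlx.core as mx

policy_chosen_score = mx.array([-9.5])
policy_rejected_score = mx.array([-13.0])

# ref_model is None -> reference scores are zero, as in the branch above
reference_chosen_score = mx.zeros_like(policy_chosen_score)
reference_rejected_score = mx.zeros_like(policy_rejected_score)

margin_with_zero_ref = (policy_chosen_score - reference_chosen_score) - (
    policy_rejected_score - reference_rejected_score
)
margin_reference_free = policy_chosen_score - policy_rejected_score
print(mx.allclose(margin_with_zero_ref, margin_reference_free))  # True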
@@ -177,7 +170,7 @@ def iterate_dpo_batches(dataset, batch_size, max_seq_length, train=False):
 
 def evaluate_dpo(
     model,
-    reference_model,
+    ref_model,
     dataset,
     batch_size,
     num_batches,
@@ -206,7 +199,7 @@ def evaluate_dpo(
 
         loss, reward, toks, metrics = loss_fn(
             model=model,
-            reference_teacher_model=reference_model,
+            ref_model=ref_model,
             chosen=chosen,
             rejected=rejected,
             chosen_masks=chosen_masks,
@@ -240,7 +233,7 @@ def evaluate_dpo(
 
 def train_dpo(
     model,
-    reference_model,
+    ref_model,
     tokenizer,
     optimizer,
     train_dataset,
@@ -267,7 +260,7 @@ def train_dpo(
 
         (loss, reward, toks, metrics), grad = loss_value_and_grad(
             model,
-            reference_model,
+            ref_model,
             chosen,
             rejected,
             chosen_masks,
@@ -289,8 +282,7 @@ def train_dpo(
             rejected_masks=rejected_masks,
             beta=args.beta,
             delta=args.delta,
-            loss_type=loss_type,
-            is_reference_free=args.is_reference_free
+            loss_type=loss_type
         )
 
     loss_value_and_grad = nn.value_and_grad(model, loss_wrapper)
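For context on the `nn.value_and_grad(model, loss_wrapper)` call retained at the end of this hunk: in MLX it wraps a loss function so that calling the result returns both the loss value and the gradients with respect to the model's trainable parameters. A minimal, self-contained sketch with a toy model and a squared-error loss standing in for the DPO loss wrapper (not the trainer's actual code):

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Linear(4, 1)

def loss_wrapper(model, x, y):
    # simple squared error standing in for the real DPO loss wrapper
    return ((model(x) - y) ** 2).mean()

loss_value_and_grad = nn.value_and_grad(model, loss_wrapper)

x = mx.random.normal((8, 4))
y = mx.random.normal((8, 1))
loss, grads = loss_value_and_grad(model, x, y)

optimizer = optim.SGD(learning_rate=1e-2)
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state, loss)
print(loss.item())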
@@ -324,7 +316,7 @@ def train_dpo(
         stop = time.perf_counter()
         val_loss, val_rewards, val_ntokens, val_metrics = evaluate_dpo(
             model=model,
-            reference_model=reference_model,
+            reference_model=ref_model,
             dataset=val_dataset,
             batch_size=args.batch_size,
             num_batches=args.val_batches,