mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-08-11 03:36:42 +08:00

commit b31d9cbb65 (parent b3d6fc38cd)
removing is-reference-free argument
@@ -95,7 +95,6 @@ The DPO training accepts the following additional parameters:
 - `--beta`: Controls the strength of the DPO loss (default: 0.1)
 - `--dpo-loss-type`: Choose between "sigmoid" (default), "hinge", "ipo", or "dpop" loss functions
-- `--is-reference-free`: Enable reference-free DPO training
 - `--delta`: Margin parameter for hinge loss (default: 50.0)
 - `--reference-model-path`: Path to a reference model for DPO training
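With the `--is-reference-free` flag gone, the remaining DPO flags are the ones kept in build_parser below. A minimal invocation sketch using only the flags confirmed by this diff (the script name is an assumption, and any required model/data flags are omitted here):

    python lora.py --beta 0.1 --dpo-loss-type sigmoid --delta 50.0 --reference-model-path ./reference_model

Leaving out --reference-model-path would presumably correspond to the `ref_model is None` branch added to dpo_loss further down, though how train_model resolves a missing path is not shown in this diff.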
@@ -66,7 +66,6 @@ CONFIG_DEFAULTS = {
     "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
     "beta": 0.1,
     "dpo_loss_type": "sigmoid",
-    "is_reference_free": False,
     "delta": 50.0,
     "reference_model_path": None,
     "train_bias_only": False,
@@ -176,7 +175,6 @@ def build_parser():
     )
     parser.add_argument("--beta", type=float)
     parser.add_argument("--dpo-loss-type", type=str, choices=["sigmoid", "hinge", "ipo", "dpop"])
-    parser.add_argument("--is-reference-free", action="store_true")
     parser.add_argument("--delta", type=float)
     parser.add_argument("--reference-model-path", type=str)
     parser.add_argument("--train-bias-only", action="store_true")
@@ -240,7 +238,6 @@ def train_model(
         grad_checkpoint=args.grad_checkpoint,
         beta=args.beta,
         loss_type=args.dpo_loss_type,
-        is_reference_free=args.is_reference_free,
         delta=args.delta,
         reference_model_path=args.reference_model_path
     )
@@ -28,12 +28,6 @@ class DPOTrainingArgs(TrainingArgs):
             "help": "DPO loss type: 'sigmoid', 'hinge', 'ipo', or 'dpop'."
         }
     )
-    is_reference_free: bool = field(
-        default=False,
-        metadata={
-            "help": "Whether to use reference-free DPO training."
-        }
-    )
     delta: float = field(
         default=50.0,
         metadata={
@@ -50,7 +44,6 @@ class DPOTrainingArgs(TrainingArgs):
 
 def dpo_loss(
     model,
-    reference_teacher_model,
     chosen: mx.array,
     rejected: mx.array,
     chosen_masks: mx.array,
@@ -58,7 +51,7 @@ def dpo_loss(
     beta: float,
     delta: float,
     loss_type: str = "sigmoid",
-    is_reference_free: bool = False
+    ref_model=None,
 ):
     def make_predictions(model, x, mask):
         inputs = x[:, :-1]
@@ -84,12 +77,12 @@ def dpo_loss(
     policy_rejected_score = policy_rejected_scores.sum(-1)
 
     # Calculate log probabilities for reference model
-    if is_reference_free:
+    if ref_model is None:
         reference_chosen_score = mx.zeros_like(policy_chosen_score)
         reference_rejected_score = mx.zeros_like(policy_rejected_score)
     else:
-        reference_chosen_scores = mx.stop_gradient(make_predictions(reference_teacher_model, chosen, chosen_masks))
-        reference_rejected_scores = mx.stop_gradient(make_predictions(reference_teacher_model, rejected, rejected_masks))
+        reference_chosen_scores = mx.stop_gradient(make_predictions(ref_model, chosen, chosen_masks))
+        reference_rejected_scores = mx.stop_gradient(make_predictions(ref_model, rejected, rejected_masks))
         if loss_type == "ipo":
             # ipo uses average log probabilities
             reference_chosen_score = reference_chosen_scores.sum(-1) / num_chosen_tokens
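As context for the beta and loss_type parameters threaded through this diff, here is a minimal, self-contained sketch of the standard sigmoid DPO objective that scores like these typically feed into. It illustrates the usual formula, not the exact expression used in this repository, and it assumes the inputs are summed per-sequence log-probabilities as computed above:

import mlx.core as mx

def sigmoid_dpo_loss_sketch(policy_chosen_score, policy_rejected_score,
                            reference_chosen_score, reference_rejected_score,
                            beta=0.1):
    # Standard DPO: -log(sigmoid(beta * ((pi_c - pi_r) - (ref_c - ref_r))))
    logits = (policy_chosen_score - policy_rejected_score) - (
        reference_chosen_score - reference_rejected_score
    )
    # mx.sigmoid, mx.log, and mx.mean are core MLX ops; average over the batch
    return mx.mean(-mx.log(mx.sigmoid(beta * logits)))

With ref_model=None both reference scores are zeros (as in the branch above), so the logits reduce to policy_chosen_score - policy_rejected_score, which is the reference-free behavior the removed flag used to enable.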
@@ -177,7 +170,7 @@ def iterate_dpo_batches(dataset, batch_size, max_seq_length, train=False):
 
 def evaluate_dpo(
     model,
-    reference_model,
+    ref_model,
     dataset,
     batch_size,
     num_batches,
@@ -206,7 +199,7 @@ def evaluate_dpo(
 
         loss, reward, toks, metrics = loss_fn(
             model=model,
-            reference_teacher_model=reference_model,
+            ref_model=ref_model,
             chosen=chosen,
             rejected=rejected,
             chosen_masks=chosen_masks,
@@ -240,7 +233,7 @@ def evaluate_dpo(
 
 def train_dpo(
     model,
-    reference_model,
+    ref_model,
     tokenizer,
     optimizer,
     train_dataset,
@@ -267,7 +260,7 @@ def train_dpo(
 
         (loss, reward, toks, metrics), grad = loss_value_and_grad(
             model,
-            reference_model,
+            ref_model,
             chosen,
             rejected,
             chosen_masks,
@@ -289,8 +282,7 @@ def train_dpo(
             rejected_masks=rejected_masks,
             beta=args.beta,
             delta=args.delta,
-            loss_type=loss_type,
-            is_reference_free=args.is_reference_free
+            loss_type=loss_type
         )
 
     loss_value_and_grad = nn.value_and_grad(model, loss_wrapper)
@@ -324,7 +316,7 @@ def train_dpo(
         stop = time.perf_counter()
         val_loss, val_rewards, val_ntokens, val_metrics = evaluate_dpo(
             model=model,
-            reference_model=reference_model,
+            reference_model=ref_model,
             dataset=val_dataset,
             batch_size=args.batch_size,
             num_batches=args.val_batches,