cleaning up some naming

Goekdeniz-Guelmez 2025-01-31 21:27:59 +01:00
parent b379359385
commit 5998272ec2
2 changed files with 40 additions and 19 deletions

View File

@@ -67,8 +67,7 @@ CONFIG_DEFAULTS = {
     "beta": 0.1,
     "dpo_loss_type": "sigmoid",
     "delta": 50.0,
-    "reference_model_path": None,
-    "train_bias_only": False,
+    "reference_model_path": None
 }
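For context, these defaults are typically merged with user-supplied settings before training; a minimal sketch of that pattern (the merge logic is an assumption, not part of this commit):

# Hypothetical merge of user overrides into the defaults above; only
# CONFIG_DEFAULTS is taken from the hunk, the rest is illustrative.
CONFIG_DEFAULTS = {
    "beta": 0.1,
    "dpo_loss_type": "sigmoid",
    "delta": 50.0,
    "reference_model_path": None,
}

user_config = {"dpo_loss_type": "dpop", "beta": 0.2}  # e.g. parsed from YAML/CLI
config = {**CONFIG_DEFAULTS, **user_config}
assert config["delta"] == 50.0  # unset keys keep their defaults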
@@ -173,12 +172,35 @@ def build_parser():
         help="Use gradient checkpointing to reduce memory use.",
         default=None,
     )
-    parser.add_argument("--beta", type=float)
-    parser.add_argument("--dpo-loss-type", type=str, choices=["sigmoid", "hinge", "ipo", "dpop"])
-    parser.add_argument("--delta", type=float)
-    parser.add_argument("--reference-model-path", type=str)
-    parser.add_argument("--train-bias-only", action="store_true")
     parser.add_argument("--seed", type=int, help="The PRNG seed")
+    # DPO args
+    parser.add_argument(
+        "--beta",
+        type=float,
+        help="Temperature parameter for DPO training.",
+        default=0.1
+    )
+    parser.add_argument(
+        "--dpo-loss-type",
+        type=str,
+        help="DPO loss type: 'sigmoid', 'hinge', 'ipo', or 'dpop'.",
+        choices=["sigmoid", "hinge", "ipo", "dpop"],
+        default="sigmoid"
+    )
+    parser.add_argument(
+        "--delta",
+        type=float,
+        help="Delta parameter for DPOP loss type.",
+        default=50.0
+    )
+    parser.add_argument(
+        "--reference-model-path",
+        type=str,
+        help="Path to reference model weights. If None, uses the same model.",
+        default=None
+    )
     return parser
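A stand-alone sketch of the four flags added above, mirroring the diff (the real build_parser defines many more training options not reproduced here):

import argparse

# Minimal parser with only the DPO arguments from the hunk above.
parser = argparse.ArgumentParser(description="DPO training options (sketch)")
parser.add_argument("--beta", type=float, default=0.1,
                    help="Temperature parameter for DPO training.")
parser.add_argument("--dpo-loss-type", type=str, default="sigmoid",
                    choices=["sigmoid", "hinge", "ipo", "dpop"],
                    help="DPO loss type.")
parser.add_argument("--delta", type=float, default=50.0,
                    help="Delta parameter for DPOP loss type.")
parser.add_argument("--reference-model-path", type=str, default=None,
                    help="Path to reference model weights; None reuses the policy model.")

args = parser.parse_args(["--dpo-loss-type", "dpop", "--beta", "0.2"])
assert args.delta == 50.0  # unset flags fall back to the new defaults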

View File

@@ -12,7 +12,6 @@ import mlx.nn as nn
 import numpy as np
 from mlx.nn.utils import average_gradients
 from mlx.utils import tree_flatten
-from ..generate import generate
 from .trainer import TrainingCallback, grad_checkpoint, TrainingArgs
@@ -100,7 +99,6 @@ def dpo_loss(
     elif loss_type == "ipo":
         losses = (logits - 1 / (2 * beta)) ** 2
     elif loss_type == "dpop":
-        delta = 50
         penalty = mx.maximum(mx.zeros_like(policy_chosen_score), reference_chosen_score - policy_chosen_score)
         losses = -(nn.log_sigmoid(beta * logits) - delta * penalty)
     else:
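With the hard-coded delta removed, the branches fit the usual DPO loss family. A self-contained sketch for reference; the function name, the sigmoid and hinge branches, and the toy scores are assumptions, while the ipo and dpop branches mirror the hunk above:

import mlx.core as mx
import mlx.nn as nn

def dpo_losses(logits, beta, delta, loss_type,
               policy_chosen_score=None, reference_chosen_score=None):
    # `logits` is the preference margin:
    # (policy_chosen - policy_rejected) - (ref_chosen - ref_rejected).
    if loss_type == "sigmoid":    # standard DPO
        return -nn.log_sigmoid(beta * logits)
    elif loss_type == "hinge":    # assumed form, matching the common hinge variant
        return nn.relu(1.0 - beta * logits)
    elif loss_type == "ipo":
        return (logits - 1 / (2 * beta)) ** 2
    elif loss_type == "dpop":     # DPOP: penalize drops in chosen log-likelihood
        penalty = mx.maximum(mx.zeros_like(policy_chosen_score),
                             reference_chosen_score - policy_chosen_score)
        return -(nn.log_sigmoid(beta * logits) - delta * penalty)
    raise ValueError(f"Unknown loss type: {loss_type}")

# Toy usage with fabricated per-example scores:
margin = mx.array([0.5, -0.2])
chosen_pol = mx.array([-1.0, -2.5])
chosen_ref = mx.array([-1.2, -2.0])
print(dpo_losses(margin, beta=0.1, delta=50.0, loss_type="dpop",
                 policy_chosen_score=chosen_pol,
                 reference_chosen_score=chosen_ref))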
@@ -178,7 +176,7 @@ def evaluate_dpo(
     delta: float,
     max_seq_length,
     loss_type,
-    loss_fn: callable = dpo_loss
+    loss: callable = dpo_loss
 ):
     all_losses = 0
     all_rewards = mx.zeros((2,))
@@ -197,7 +195,7 @@ def evaluate_dpo(
     ):
         chosen, rejected, chosen_masks, rejected_masks = batch
-        loss, reward, toks, metrics = loss_fn(
+        lvalue, reward, toks, metrics = loss(
             model=model,
             ref_model=ref_model,
             chosen=chosen,
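Binding the result to lvalue matters here because the parameter is now named loss: assigning the call's result back to loss would rebind the callable inside the loop. A minimal illustration of the failure mode:

# Toy stand-in for the `loss` callable; names are illustrative only.
def loss(x):
    return x * 2.0

results = []
for x in (1.0, 2.0):
    lvalue = loss(x)      # keeps the callable intact across iterations
    results.append(lvalue)
# Writing `loss = loss(x)` instead would rebind `loss` to a float and
# raise "'float' object is not callable" on the second iteration.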
@@ -239,7 +237,7 @@ def train_dpo(
     train_dataset,
     val_dataset,
     args: DPOTrainingArgs = DPOTrainingArgs(),
-    loss_fn: callable = dpo_loss,
+    loss: callable = dpo_loss,
     training_callback: TrainingCallback = None,
     loss_type="sigmoid",
 ):
@@ -258,7 +256,7 @@ def train_dpo(
     def step(batch):
         chosen, rejected, chosen_masks, rejected_masks = batch
-        (loss, reward, toks, metrics), grad = loss_value_and_grad(
+        (lvalue, reward, toks, metrics), grad = loss_value_and_grad(
             model,
             ref_model,
             chosen,
@@ -270,10 +268,10 @@ def train_dpo(
         grad = average_gradients(grad)
         optimizer.update(model, grad)
-        return loss, reward, toks, metrics
+        return lvalue, reward, toks, metrics

     def loss_wrapper(model, ref_model, chosen, rejected, chosen_masks, rejected_masks):
-        return loss_fn(
+        return loss(
             model=model,
             reference_teacher_model=ref_model,
             chosen=chosen,
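The loss_value_and_grad used in step is presumably built from loss_wrapper with nn.value_and_grad; that wiring is not shown in this diff. A minimal, self-contained sketch of the MLX pattern on a toy model:

import mlx.core as mx
import mlx.nn as nn

# Toy model and loss; stand-ins for the policy model and loss_wrapper above.
model = nn.Linear(4, 1)

def loss_fn(model, x, y):
    return nn.losses.mse_loss(model(x), y)

# Returns a function computing (loss value, gradients of model parameters).
loss_value_and_grad = nn.value_and_grad(model, loss_fn)

x = mx.random.normal((8, 4))
y = mx.random.normal((8, 1))
lvalue, grad = loss_value_and_grad(model, x, y)  # scalar loss + gradient tree
mx.eval(lvalue, grad)
print(lvalue.item())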
@@ -311,7 +309,6 @@ def train_dpo(
             train=True,
         ),
     ):
-        # Report validation loss if needed
         if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
             stop = time.perf_counter()
             val_loss, val_rewards, val_ntokens, val_metrics = evaluate_dpo(
@@ -321,7 +318,7 @@ def train_dpo(
                 batch_size=args.batch_size,
                 num_batches=args.val_batches,
                 max_seq_length=args.max_seq_length,
-                loss_fn=loss_fn,
+                loss=loss,
                 beta=args.beta,
                 delta=args.delta,
                 loss_type=loss_type,
@@ -351,13 +348,15 @@ def train_dpo(
         start = time.perf_counter()

-        loss, reward, toks, metrics = step(batch)
-        losses += loss
+        lvalue, reward, toks, metrics = step(batch)
+        losses += lvalue
         rewards += reward
         n_tokens += toks
         steps += 1

         for k, v in metrics.items():
             accumulated_metrics[k] += v

         mx.eval(state, losses, rewards, n_tokens)

         if it % args.steps_per_report == 0 or it == args.iters:
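The mx.eval(state, losses, rewards, n_tokens) call matters because MLX is lazy: the per-step accumulations above only build a compute graph until something forces evaluation. A tiny demonstration of that model:

import mlx.core as mx

# MLX operations are queued, not executed, until mx.eval (or a data
# access such as .item()) materializes the result.
losses = mx.zeros((1,))
for _ in range(3):
    losses = losses + 1.0  # graph node only; nothing computed yet
mx.eval(losses)            # forces the accumulated computation
print(losses.item())       # 3.0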