mlx-examples (mirror of https://github.com/ml-explore/mlx-examples.git)

commit 5998272ec2
parent b379359385

    cleaning up some namings
@@ -67,8 +67,7 @@ CONFIG_DEFAULTS = {
     "beta": 0.1,
     "dpo_loss_type": "sigmoid",
     "delta": 50.0,
-    "reference_model_path": None,
-    "train_bias_only": False,
+    "reference_model_path": None
 }
 
 
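For context, a minimal sketch of how these DPO defaults might be consumed. The CONFIG_DEFAULTS entries are taken from the hunk above; the merge helper and the example call are hypothetical, not part of this commit:

# Hypothetical illustration: DPO-related entries from CONFIG_DEFAULTS above.
CONFIG_DEFAULTS = {
    "beta": 0.1,
    "dpo_loss_type": "sigmoid",
    "delta": 50.0,
    "reference_model_path": None,
}

def with_defaults(user_config: dict) -> dict:
    # Fill any DPO option the user left unset (None) with its default.
    merged = dict(CONFIG_DEFAULTS)
    merged.update({k: v for k, v in user_config.items() if v is not None})
    return merged

print(with_defaults({"beta": 0.2, "delta": None}))
# {'beta': 0.2, 'dpo_loss_type': 'sigmoid', 'delta': 50.0, 'reference_model_path': None}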
@@ -173,12 +172,35 @@ def build_parser():
         help="Use gradient checkpointing to reduce memory use.",
         default=None,
     )
-    parser.add_argument("--beta", type=float)
-    parser.add_argument("--dpo-loss-type", type=str, choices=["sigmoid", "hinge", "ipo", "dpop"])
-    parser.add_argument("--delta", type=float)
-    parser.add_argument("--reference-model-path", type=str)
-    parser.add_argument("--train-bias-only", action="store_true")
     parser.add_argument("--seed", type=int, help="The PRNG seed")
 
+    # DPO args
+    parser.add_argument(
+        "--beta",
+        type=float,
+        help="Temperature parameter for DPO training.",
+        default=0.1
+    )
+    parser.add_argument(
+        "--dpo-loss-type",
+        type=str,
+        help="DPO loss type: 'sigmoid', 'hinge', 'ipo', or 'dpop'.",
+        choices=["sigmoid", "hinge", "ipo", "dpop"],
+        default="sigmoid"
+    )
+    parser.add_argument(
+        "--delta",
+        type=float,
+        help="Delta parameter for DPOP loss type.",
+        default=50.0
+    )
+    parser.add_argument(
+        "--reference-model-path",
+        type=str,
+        help="Path to reference model weights. If None, uses the same model.",
+        default=None
+    )
+
     return parser
 
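Because every new flag carries an explicit default, the DPO options work even when none of them are passed on the command line. A self-contained sketch of just these four arguments (single-line form; the real build_parser() defines many more options):

import argparse

# Standalone sketch of only the DPO flags added above.
parser = argparse.ArgumentParser()
parser.add_argument("--beta", type=float, default=0.1)
parser.add_argument("--dpo-loss-type", type=str, choices=["sigmoid", "hinge", "ipo", "dpop"], default="sigmoid")
parser.add_argument("--delta", type=float, default=50.0)
parser.add_argument("--reference-model-path", type=str, default=None)

args = parser.parse_args(["--dpo-loss-type", "dpop", "--delta", "25.0"])
print(args.beta, args.dpo_loss_type, args.delta, args.reference_model_path)
# 0.1 dpop 25.0 None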
@@ -12,7 +12,6 @@ import mlx.nn as nn
 import numpy as np
 from mlx.nn.utils import average_gradients
 from mlx.utils import tree_flatten
-from ..generate import generate
 from .trainer import TrainingCallback, grad_checkpoint, TrainingArgs
 
 
@@ -100,7 +99,6 @@ def dpo_loss(
     elif loss_type == "ipo":
         losses = (logits - 1 / (2 * beta)) ** 2
     elif loss_type == "dpop":
-        delta = 50
         penalty = mx.maximum(mx.zeros_like(policy_chosen_score), reference_chosen_score - policy_chosen_score)
         losses = -(nn.log_sigmoid(beta * logits) - delta * penalty)
     else:
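The only behavioral change here is that the DPOP branch now uses the `delta` argument (surfaced as --delta, default 50.0) instead of a hard-coded 50. A self-contained sketch of the branch logic: the ipo and dpop expressions come from the hunk above, while the sigmoid and hinge branches are the standard DPO forms and are assumed, not shown in this diff:

import mlx.core as mx
import mlx.nn as nn

def dpo_branch_losses(logits, policy_chosen_score, reference_chosen_score,
                      beta=0.1, delta=50.0, loss_type="sigmoid"):
    # `logits` is the policy-vs-reference log-ratio margin used by DPO.
    if loss_type == "sigmoid":   # assumed standard DPO loss, not shown in this hunk
        return -nn.log_sigmoid(beta * logits)
    elif loss_type == "hinge":   # assumed hinge variant, not shown in this hunk
        return mx.maximum(0.0, 1 - beta * logits)
    elif loss_type == "ipo":     # as in the hunk above
        return (logits - 1 / (2 * beta)) ** 2
    elif loss_type == "dpop":    # as in the hunk above; delta is now a parameter
        penalty = mx.maximum(mx.zeros_like(policy_chosen_score),
                             reference_chosen_score - policy_chosen_score)
        return -(nn.log_sigmoid(beta * logits) - delta * penalty)
    raise ValueError(f"Unknown loss_type: {loss_type}")

logits = mx.array([0.3, -0.2])
scores = mx.array([1.0, 2.0])
print(dpo_branch_losses(logits, scores, scores + 0.5, loss_type="dpop"))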
@@ -178,7 +176,7 @@ def evaluate_dpo(
     delta: float,
     max_seq_length,
     loss_type,
-    loss_fn: callable = dpo_loss
+    loss: callable = dpo_loss
 ):
     all_losses = 0
     all_rewards = mx.zeros((2,))
@@ -197,7 +195,7 @@ def evaluate_dpo(
     ):
         chosen, rejected, chosen_masks, rejected_masks = batch
 
-        loss, reward, toks, metrics = loss_fn(
+        loss, reward, toks, metrics = loss(
             model=model,
             ref_model=ref_model,
             chosen=chosen,
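One side effect of the rename worth noting: in the loop above the per-batch result is still unpacked into a variable named `loss`, which rebinds the `loss` callable to an array after the first batch; the train_dpo hunks below avoid this by unpacking into `lvalue` instead. A minimal, self-contained sketch of that safer pattern, using a toy model and loss rather than the repo's dpo_loss:

import mlx.core as mx
import mlx.nn as nn

# Toy model and loss used only to illustrate the naming pattern.
model = nn.Linear(4, 1)

def toy_loss(model, x, y):
    return mx.mean((model(x) - y) ** 2)

def run(batches, loss=toy_loss):
    # Mirrors the trainer: wrap the passed-in `loss` callable for value+grad.
    loss_value_and_grad = nn.value_and_grad(model, loss)
    total = mx.array(0.0)
    for x, y in batches:
        # Unpack into `lvalue`: assigning to `loss` here would shadow the
        # callable after the first batch.
        lvalue, grad = loss_value_and_grad(model, x, y)
        total += lvalue
    return total

batches = [(mx.random.normal((8, 4)), mx.random.normal((8, 1))) for _ in range(2)]
print(run(batches).item())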
@@ -239,7 +237,7 @@ def train_dpo(
     train_dataset,
     val_dataset,
     args: DPOTrainingArgs = DPOTrainingArgs(),
-    loss_fn: callable = dpo_loss,
+    loss: callable = dpo_loss,
     training_callback: TrainingCallback = None,
     loss_type="sigmoid",
 ):
@@ -258,7 +256,7 @@ def train_dpo(
     def step(batch):
         chosen, rejected, chosen_masks, rejected_masks = batch
 
-        (loss, reward, toks, metrics), grad = loss_value_and_grad(
+        (lvalue, reward, toks, metrics), grad = loss_value_and_grad(
             model,
             ref_model,
             chosen,
@@ -270,10 +268,10 @@ def train_dpo(
         grad = average_gradients(grad)
         optimizer.update(model, grad)
 
-        return loss, reward, toks, metrics
+        return lvalue, reward, toks, metrics
 
     def loss_wrapper(model, ref_model, chosen, rejected, chosen_masks, rejected_masks):
-        return loss_fn(
+        return loss(
             model=model,
             reference_teacher_model=ref_model,
             chosen=chosen,
@@ -311,7 +309,6 @@ def train_dpo(
             train=True,
         ),
     ):
-        # Report validation loss if needed
        if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
            stop = time.perf_counter()
            val_loss, val_rewards, val_ntokens, val_metrics = evaluate_dpo(
|
|||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size,
|
||||||
num_batches=args.val_batches,
|
num_batches=args.val_batches,
|
||||||
max_seq_length=args.max_seq_length,
|
max_seq_length=args.max_seq_length,
|
||||||
loss_fn=loss_fn,
|
loss=loss,
|
||||||
beta=args.beta,
|
beta=args.beta,
|
||||||
delta=args.delta,
|
delta=args.delta,
|
||||||
loss_type=loss_type,
|
loss_type=loss_type,
|
||||||
@@ -351,13 +348,15 @@ def train_dpo(
 
         start = time.perf_counter()
 
-        loss, reward, toks, metrics = step(batch)
-        losses += loss
+        lvalue, reward, toks, metrics = step(batch)
+        losses += lvalue
         rewards += reward
         n_tokens += toks
         steps += 1
 
         for k, v in metrics.items():
             accumulated_metrics[k] += v
 
         mx.eval(state, losses, rewards, n_tokens)
 
         if it % args.steps_per_report == 0 or it == args.iters:
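The `mx.eval(state, losses, rewards, n_tokens)` call each step matters because MLX arrays are lazy: the accumulations above only record a computation graph until something forces evaluation. A tiny self-contained illustration of that accumulation pattern (toy values, not the trainer's):

import mlx.core as mx

losses = mx.array(0.0)
n_tokens = mx.array(0)
for step_loss, toks in [(0.50, 128), (0.25, 256)]:
    losses += mx.array(step_loss)
    n_tokens += mx.array(toks)
    # Force evaluation each step so the graph does not grow across iterations.
    mx.eval(losses, n_tokens)

print(losses.item(), n_tokens.item())  # 0.75 384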