mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-26 18:51:18 +08:00)

commit 541677aa7f
parent ceccb4c9e9

    cleaning up
@@ -66,7 +66,6 @@ CONFIG_DEFAULTS = {
     "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
     "beta": 0.1,
     "dpo_loss_type": "sigmoid",
-    "is_reference_free": False,
     "delta": 50.0,
     "reference_model_path": None,
     "reward_scaling": 1.0,
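The hunk above drops `is_reference_free`, a DPO-only key, from `CONFIG_DEFAULTS`. For orientation, here is a minimal sketch (not part of the commit) of how a defaults dict like this is typically merged with CLI values, assuming the common convention that unset CLI options arrive as None; the helper name `apply_defaults` is hypothetical.

# Hypothetical helper, for illustration only: fill any CLI value that is
# still None from CONFIG_DEFAULTS, leaving explicit user choices untouched.
CONFIG_DEFAULTS = {
    "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
    "beta": 0.1,
    "reward_scaling": 1.0,
}

def apply_defaults(cli_args: dict, defaults: dict) -> dict:
    merged = dict(cli_args)
    for key, value in defaults.items():
        if merged.get(key) is None:
            merged[key] = value
    return merged

print(apply_defaults({"beta": None, "reward_scaling": 2.0}, CONFIG_DEFAULTS))
# beta falls back to 0.1, reward_scaling keeps the user-supplied 2.0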
@@ -174,13 +173,21 @@ def build_parser():
         help="Use gradient checkpointing to reduce memory use.",
         default=None,
     )
-    parser.add_argument("--beta", type=float)
-    parser.add_argument("--dpo-loss-type", type=str, choices=["sigmoid", "hinge", "ipo", "dpo"])
-    parser.add_argument("--is-reference-free", action="store_true")
-    parser.add_argument("--delta", type=float)
-    parser.add_argument("--reference-model-path", type=str)
-    parser.add_argument("--reward-scaling", type=float, help="Scaling factor for offline rewards.")
     parser.add_argument("--seed", type=int, help="The PRNG seed.")
+
+    # ORPO args
+    parser.add_argument(
+        "--beta",
+        type=float,
+        help="Temperature parameter for ORPO training.",
+        default=0.1
+    )
+    parser.add_argument(
+        "--reward-scaling",
+        type=float,
+        help="Reward scaling factor for ORPO training, not implemented.",
+        default=1.0
+    )
     return parser

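The old bare DPO flags are replaced by two ORPO flags that carry real defaults (0.1 and 1.0) rather than `default=None`, so they always have a value even when the user passes nothing. A standalone sketch of just these two flags, runnable on its own (the real parser declares many more arguments):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--beta",
    type=float,
    help="Temperature parameter for ORPO training.",
    default=0.1,
)
parser.add_argument(
    "--reward-scaling",
    type=float,
    help="Reward scaling factor for ORPO training, not implemented.",
    default=1.0,
)

args = parser.parse_args(["--beta", "0.2"])
print(args.beta, args.reward_scaling)  # 0.2 1.0; dashes become underscores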
@@ -239,7 +246,8 @@ def train_model(
             adapter_file=adapter_file,
             max_seq_length=args.max_seq_length,
             grad_checkpoint=args.grad_checkpoint,
-            beta=args.beta
+            beta=args.beta,
+            reward_scaling=args.reward_scaling
         )

         train_orpo(
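With the trailing comma restored after `beta=args.beta`, the parsed `reward_scaling` is forwarded into the ORPO training arguments alongside it. A reduced sketch of that hand-off; `ORPOTrainingArgs` here is a stub with only the two fields this commit touches, and `make_training_args` is a hypothetical stand-in for the relevant part of `train_model`:

from dataclasses import dataclass
from types import SimpleNamespace

@dataclass
class ORPOTrainingArgs:           # stub: the real class has many more fields
    beta: float = 0.1
    reward_scaling: float = 1.0

def make_training_args(args):
    # mirrors the two keyword arguments added/fixed in the hunk above
    return ORPOTrainingArgs(
        beta=args.beta,
        reward_scaling=args.reward_scaling,
    )

cli = SimpleNamespace(beta=0.2, reward_scaling=1.0)
print(make_training_args(cli))    # ORPOTrainingArgs(beta=0.2, reward_scaling=1.0)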
@@ -288,7 +296,7 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set
             max_seq_length=args.max_seq_length,
             beta=args.beta
         )
-        print(f"Test loss {test_loss:.8f}, Rewards: {test_rewards[0]:.3f}, {test_rewards[1]:.3f}")
+        print(f"Test loss {test_loss:.3f}, Rewards: {test_rewards[0]:.3f}, {test_rewards[1]:.3f}")
     else:
         test_loss = evaluate(
             model=model,

@@ -16,6 +16,10 @@ class ORPOTrainingArgs(TrainingArgs):
         default=0.1,
         metadata={"help": "Temperature parameter for ORPO training."}
     )
+    reward_scaling: float = field(
+        default=1.0,
+        metadata={"help": "Reward scaling factor for ORPO training, not implemented."}
+    )


 def orpo_loss(model, chosen, rejected, chosen_masks, rejected_masks, preference_scores, beta=0.1):
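The dataclass picks up a `reward_scaling` field next to `beta`. A self-contained sketch of the resulting shape; `TrainingArgs` is stubbed here so the snippet runs without `mlx_lm` installed, and the stub's single field is made up:

from dataclasses import dataclass, field

@dataclass
class TrainingArgs:               # stub for the real mlx_lm base class
    batch_size: int = 4

@dataclass
class ORPOTrainingArgs(TrainingArgs):
    beta: float = field(
        default=0.1,
        metadata={"help": "Temperature parameter for ORPO training."},
    )
    reward_scaling: float = field(
        default=1.0,
        metadata={"help": "Reward scaling factor for ORPO training, not implemented."},
    )

print(ORPOTrainingArgs(beta=0.2))
# ORPOTrainingArgs(batch_size=4, beta=0.2, reward_scaling=1.0)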
@@ -131,7 +135,7 @@ def evaluate_orpo(model, dataset, batch_size, num_batches, beta: float, max_seq_
         ),
     ):
         chosen, rejected, chosen_masks, rejected_masks, preference_scores = batch
-        loss, reward, toks, metrics = orpo_loss(
+        lvalue, reward, toks, metrics = orpo_loss(
             model=model,
             chosen=chosen,
             rejected=rejected,
@@ -140,7 +144,7 @@ def evaluate_orpo(model, dataset, batch_size, num_batches, beta: float, max_seq_
             preference_scores=preference_scores,
             beta=beta
         )
-        all_losses += loss * toks
+        all_losses += lvalue * toks
         all_rewards += reward * toks
         ntokens += toks

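Renaming the per-batch scalar to `lvalue` leaves the token-weighted accumulation unchanged: each batch contributes its mean loss weighted by its token count, and the result is presumably divided by `ntokens` at the end. The same arithmetic with plain Python numbers (the batch values below are invented):

# (mean loss, token count) per batch -- made-up numbers for illustration
batches = [(0.52, 830), (0.61, 412), (0.47, 1015)]

all_losses = 0.0
ntokens = 0
for lvalue, toks in batches:
    all_losses += lvalue * toks   # weight each batch by its token count
    ntokens += toks

print(all_losses / ntokens)       # token-weighted mean, not a plain average of the three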
@@ -169,6 +173,7 @@ def train_orpo(
     optimizer,
     train_dataset,
     val_dataset,
+    loss: callable = orpo_loss,
     args: ORPOTrainingArgs = ORPOTrainingArgs(),
     training_callback: TrainingCallback = None,
 ):
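Adding `loss: callable = orpo_loss` turns the loss into an injectable dependency: by default nothing changes, but a caller can hand `train_orpo` a different function with the same signature. A toy sketch of the pattern; the loss bodies below are placeholders, not the real ORPO objective:

import math

def default_loss(chosen_score, rejected_score, beta=0.1):
    # placeholder objective: log-sigmoid of the scaled score gap
    return -math.log(1.0 / (1.0 + math.exp(-beta * (chosen_score - rejected_score))))

def train(batches, loss: callable = default_loss):
    # the trainer only ever calls `loss`, so swapping it requires no other change
    return sum(loss(c, r) for c, r in batches) / len(batches)

pairs = [(-1.2, -2.5), (-0.8, -1.1)]
print(train(pairs))                               # uses the default
print(train(pairs, loss=lambda c, r: r - c))      # injected variant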
@@ -188,7 +193,7 @@ def train_orpo(
     def step(batch):
         chosen, rejected, chosen_masks, rejected_masks, preference_scores = batch

-        (loss, reward, toks, metrics), grad = loss_value_and_grad(
+        (lvalue, reward, toks, metrics), grad = loss_value_and_grad(
             model,
             chosen,
             rejected,
@@ -200,10 +205,10 @@ def train_orpo(
         grad = average_gradients(grad)
         optimizer.update(model, grad)

-        return loss, reward, toks, metrics
+        return lvalue, reward, toks, metrics

     def loss_wrapper(model, chosen, rejected, chosen_masks, rejected_masks, preference_scores):
-        return orpo_loss(
+        return loss(
             model=model,
             chosen=chosen,
             rejected=rejected,
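`loss_wrapper` now forwards to the injected `loss` instead of hard-coding `orpo_loss`, while the gradient transform is still built once over the wrapper's fixed positional signature. A plain-Python sketch of that shape, using a crude finite-difference stand-in rather than the real `nn.value_and_grad`:

def value_and_grad(fn, eps=1e-6):
    # stand-in for the real autodiff transform: derivative w.r.t. the first arg
    def inner(w, *args):
        v = fn(w, *args)
        return v, (fn(w + eps, *args) - v) / eps
    return inner

def make_step(loss):
    def loss_wrapper(w, chosen, rejected):
        # delegate to whatever loss was injected -- nothing here names a concrete loss
        return loss(w=w, chosen=chosen, rejected=rejected)
    return value_and_grad(loss_wrapper)

def toy_loss(w, chosen, rejected):
    return (w * chosen - rejected) ** 2

step = make_step(toy_loss)
print(step(0.5, 2.0, 1.0))   # (0.0, ~0.0): value and approximate gradient w.r.t. w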
@@ -254,7 +259,7 @@ def train_orpo(
             if rank == 0:
                 print(
                     f"Iter {it}: "
-                    f"Val loss {val_loss:.8f}, "
+                    f"Val loss {val_loss:.3f}, "
                     f"Val chosen reward {val_rewards[0]:.3f}, "
                     f"Val rejected reward {val_rewards[1]:.3f}, "
                     f"Val accuracy {val_metrics['accuracies']:.3f}, "
@@ -276,13 +281,15 @@ def train_orpo(
             start = time.perf_counter()

         # Training step
-        loss, reward, toks, metrics = step(batch)
-        losses += loss
+        lvalue, reward, toks, metrics = step(batch)
+        losses += lvalue
         rewards += reward
         n_tokens += toks
         steps += 1

         for k, v in metrics.items():
             accumulated_metrics[k] += v

         mx.eval(state, losses, rewards, n_tokens)

         if it % args.steps_per_report == 0 or it == args.iters:
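The rename from `loss` to `lvalue` in the loop body matches the earlier changes and, presumably, avoids rebinding the name of the injected `loss` callable inside `train_orpo`. A toy reproduction of the pitfall the rename sidesteps:

def train(batches, loss=lambda b: 0.5 * b):
    total = 0.0
    for b in batches:
        lvalue = loss(b)    # keeps `loss` bound to the callable
        # loss = loss(b)    # would rebind `loss` to a float and break the next call
        total += lvalue
    return total

print(train([1.0, 2.0, 3.0]))   # 3.0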
@@ -300,7 +307,7 @@ def train_orpo(

             if rank == 0:
                 print(
-                    f"Iter {it}: Train loss {train_loss:.8f}, "
+                    f"Iter {it}: Train loss {train_loss:.3f}, "
                     f"Chosen reward {train_rewards[0]:.3f}, "
                     f"Rejected reward {train_rewards[1]:.3f}, "
                     f"Accuracy {avg_metrics['accuracies']:.3f}, "
|