cleaning up

Goekdeniz-Guelmez
2025-01-31 21:36:24 +01:00
parent ceccb4c9e9
commit 541677aa7f
2 changed files with 34 additions and 19 deletions

@@ -16,6 +16,10 @@ class ORPOTrainingArgs(TrainingArgs):
         default=0.1,
         metadata={"help": "Temperature parameter for ORPO training."}
     )
+    reward_scaling: float = field(
+        default=1.0,
+        metadata={"help": "Reward scaling factor for ORPO training, not implemented."}
+    )
 def orpo_loss(model, chosen, rejected, chosen_masks, rejected_masks, preference_scores, beta=0.1):
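
Note: a hypothetical usage sketch for the new field. ORPOTrainingArgs is the dataclass shown above; "beta" is assumed to be the name of the temperature field whose default and metadata appear in this hunk, and iters/steps_per_report are inherited TrainingArgs fields referenced later in the training loop. reward_scaling is declared here but, per its help text, not yet wired into the loss.

args = ORPOTrainingArgs(
    beta=0.1,            # ORPO temperature (field shown above)
    reward_scaling=1.0,  # placeholder; help text says it is not implemented yet
    iters=1000,
    steps_per_report=10,
)
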
@@ -131,7 +135,7 @@ def evaluate_orpo(model, dataset, batch_size, num_batches, beta: float, max_seq_
         ),
     ):
         chosen, rejected, chosen_masks, rejected_masks, preference_scores = batch
-        loss, reward, toks, metrics = orpo_loss(
+        lvalue, reward, toks, metrics = orpo_loss(
             model=model,
             chosen=chosen,
             rejected=rejected,
@@ -140,7 +144,7 @@ def evaluate_orpo(model, dataset, batch_size, num_batches, beta: float, max_seq_
             preference_scores=preference_scores,
             beta=beta
         )
-        all_losses += loss * toks
+        all_losses += lvalue * toks
         all_rewards += reward * toks
         ntokens += toks
@@ -169,6 +173,7 @@ def train_orpo(
     optimizer,
     train_dataset,
     val_dataset,
+    loss: callable = orpo_loss,
     args: ORPOTrainingArgs = ORPOTrainingArgs(),
     training_callback: TrainingCallback = None,
 ):
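
Note: the new `loss: callable = orpo_loss` parameter makes the loss injectable. The sketch below is illustrative only: it defines a simplified pairwise preference loss (not the ORPO odds-ratio objective implemented by orpo_loss elsewhere in this file) whose signature and (loss, reward, num_tokens, metrics) return shape are inferred from the surrounding hunks, so it could be passed as train_orpo(..., loss=custom_preference_loss).

import mlx.core as mx
import mlx.nn as nn


def custom_preference_loss(model, chosen, rejected, chosen_masks, rejected_masks,
                           preference_scores, beta=0.1):
    """Illustrative stand-in with the same call/return shape as orpo_loss."""

    def masked_logps(tokens, masks):
        # Next-token log-likelihood averaged over the unmasked positions.
        logits = model(tokens[:, :-1])
        targets = tokens[:, 1:]
        masks = masks[:, 1:]
        logps = -nn.losses.cross_entropy(logits, targets, reduction="none")
        return (logps * masks).sum(-1) / mx.maximum(masks.sum(-1), 1)

    chosen_logps = masked_logps(chosen, chosen_masks)
    rejected_logps = masked_logps(rejected, rejected_masks)

    # Simple pairwise objective scaled by the per-pair preference score
    # (not the ORPO odds-ratio term computed by the real orpo_loss).
    margin = beta * (chosen_logps - rejected_logps)
    loss_value = -(nn.log_sigmoid(margin) * preference_scores).mean()

    reward = mx.stack([chosen_logps.mean(), rejected_logps.mean()])
    num_tokens = chosen_masks.sum() + rejected_masks.sum()
    metrics = {"accuracies": (chosen_logps > rejected_logps).astype(mx.float32).mean()}
    return loss_value, reward, num_tokens, metrics
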
@@ -188,7 +193,7 @@ def train_orpo(
     def step(batch):
         chosen, rejected, chosen_masks, rejected_masks, preference_scores = batch
-        (loss, reward, toks, metrics), grad = loss_value_and_grad(
+        (lvalue, reward, toks, metrics), grad = loss_value_and_grad(
             model,
             chosen,
             rejected,
@@ -200,10 +205,10 @@ def train_orpo(
         grad = average_gradients(grad)
         optimizer.update(model, grad)
-        return loss, reward, toks, metrics
+        return lvalue, reward, toks, metrics
     def loss_wrapper(model, chosen, rejected, chosen_masks, rejected_masks, preference_scores):
-        return orpo_loss(
+        return loss(
             model=model,
             chosen=chosen,
             rejected=rejected,
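
Note: loss_wrapper now closes over the injected `loss` callable instead of calling orpo_loss directly, so the same gradient machinery differentiates any drop-in loss. Below is a toy, self-contained illustration of that pattern with mlx's nn.value_and_grad; the names and the tiny regression loss are placeholders, not code from this file.

import mlx.core as mx
import mlx.nn as nn

model = nn.Linear(4, 1)


def injected_loss(model, x, y):
    # Stand-in for orpo_loss or any custom preference loss.
    return nn.losses.mse_loss(model(x), y, reduction="mean")


def loss_wrapper(model, x, y):
    # Delegates to whichever loss was injected, mirroring the hunk above.
    return injected_loss(model=model, x=x, y=y)


loss_value_and_grad = nn.value_and_grad(model, loss_wrapper)
x, y = mx.random.normal((8, 4)), mx.random.normal((8, 1))
lvalue, grad = loss_value_and_grad(model, x, y)
mx.eval(lvalue, grad)
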
@@ -254,7 +259,7 @@ def train_orpo(
             if rank == 0:
                 print(
                     f"Iter {it}: "
-                    f"Val loss {val_loss:.8f}, "
+                    f"Val loss {val_loss:.3f}, "
                     f"Val chosen reward {val_rewards[0]:.3f}, "
                     f"Val rejected reward {val_rewards[1]:.3f}, "
                     f"Val accuracy {val_metrics['accuracies']:.3f}, "
@@ -276,13 +281,15 @@ def train_orpo(
         start = time.perf_counter()
         # Training step
-        loss, reward, toks, metrics = step(batch)
-        losses += loss
+        lvalue, reward, toks, metrics = step(batch)
+        losses += lvalue
         rewards += reward
         n_tokens += toks
         steps += 1
         for k, v in metrics.items():
             accumulated_metrics[k] += v
         mx.eval(state, losses, rewards, n_tokens)
         if it % args.steps_per_report == 0 or it == args.iters:
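
Note: a toy, self-contained illustration (not from the diff) of the accumulate-then-average pattern used above: per-step metric dicts returned by step() are summed into accumulated_metrics, and the averages printed in the report below are presumably those sums divided by the step count.

from collections import defaultdict

accumulated_metrics = defaultdict(float)
steps = 0
for step_metrics in [{"accuracies": 0.50}, {"accuracies": 0.75}]:  # stand-in for step() outputs
    for k, v in step_metrics.items():
        accumulated_metrics[k] += v
    steps += 1

avg_metrics = {k: v / steps for k, v in accumulated_metrics.items()}
print(avg_metrics)  # {'accuracies': 0.625}
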
@@ -300,7 +307,7 @@ def train_orpo(
             if rank == 0:
                 print(
-                    f"Iter {it}: Train loss {train_loss:.8f}, "
+                    f"Iter {it}: Train loss {train_loss:.3f}, "
                     f"Chosen reward {train_rewards[0]:.3f}, "
                     f"Rejected reward {train_rewards[1]:.3f}, "
                     f"Accuracy {avg_metrics['accuracies']:.3f}, "