more metrics

Goekdeniz-Guelmez 2025-01-26 15:09:55 +01:00
parent 0ff1289bd9
commit 4d0e52f7c8


@@ -113,14 +113,23 @@ def dpo_loss(
     else:
         raise ValueError(f"Unknown loss type: {loss_type}")

-    loss = mx.mean(losses)
     num_tokens = (num_chosen_tokens + num_rejected_tokens).sum()

     chosen_reward = beta * mx.mean(policy_chosen_score - reference_chosen_score)
     rejected_reward = beta * mx.mean(policy_rejected_score - reference_rejected_score)
     reward = mx.stack([chosen_reward, rejected_reward])

-    return loss, reward, num_tokens
+    metrics = {
+        'accuracies': mx.mean((chosen_reward > rejected_reward).astype(mx.float32)),
+        'margins': mx.mean(chosen_reward - rejected_reward),
+        'policy_rejected_logps': mx.mean(policy_rejected_score / num_rejected_tokens),
+        'policy_chosen_logps': mx.mean(policy_chosen_score / num_chosen_tokens),
+        'rejected_logits_mean': mx.mean(policy_rejected_score),
+        'chosen_logits_mean': mx.mean(policy_chosen_score)
+    }
+
+    return mx.mean(losses), reward, num_tokens, metrics


 def iterate_dpo_batches(dataset, batch_size, max_seq_length, train=False):
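For reference, the new 'accuracies' and 'margins' entries are derived from the batch-mean implicit DPO rewards already computed in dpo_loss. A minimal, self-contained sketch with made-up per-sequence log-probability sums (only mlx.core is assumed; nothing here is part of the patch):

import mlx.core as mx

beta = 0.1
# Summed log-probs for a batch of 3 preference pairs (made-up values).
policy_chosen_score = mx.array([-12.0, -9.5, -15.0])
reference_chosen_score = mx.array([-13.0, -10.0, -14.0])
policy_rejected_score = mx.array([-11.0, -12.0, -16.0])
reference_rejected_score = mx.array([-10.5, -11.0, -15.5])

# Batch-mean implicit DPO rewards, mirroring dpo_loss above.
chosen_reward = beta * mx.mean(policy_chosen_score - reference_chosen_score)
rejected_reward = beta * mx.mean(policy_rejected_score - reference_rejected_score)

# 'accuracies' is 1.0 when the mean chosen reward beats the mean rejected reward,
# 'margins' is the gap between the two batch-mean rewards.
accuracy = mx.mean((chosen_reward > rejected_reward).astype(mx.float32))
margin = mx.mean(chosen_reward - rejected_reward)
print(accuracy.item(), margin.item())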
@@ -182,6 +191,7 @@ def evaluate_dpo(
 ):
     all_losses = 0
     all_rewards = mx.zeros((2,))
+    all_metrics = None
     ntokens = 0

     index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
@@ -196,7 +206,7 @@ def evaluate_dpo(
     ):
         chosen, rejected, chosen_masks, rejected_masks = batch

-        loss, reward, toks = loss_fn(
+        loss, reward, toks, metrics = loss_fn(
             model=model,
             reference_teacher_model=reference_model,
             chosen=chosen,
@@ -211,12 +221,23 @@ def evaluate_dpo(
         all_rewards += reward
         ntokens += toks

+        if all_metrics is None:
+            all_metrics = {k: v * toks for k, v in metrics.items()}
+        else:
+            for k, v in metrics.items():
+                all_metrics[k] += v * toks
+
     mx.eval(all_losses, all_rewards, ntokens)
     all_losses = mx.distributed.all_sum(all_losses)
     all_rewards = mx.distributed.all_sum(all_rewards)
     ntokens = mx.distributed.all_sum(ntokens)
+    all_metrics = {k: mx.distributed.all_sum(v) for k, v in all_metrics.items()}

-    return (all_losses / ntokens).item(), all_rewards.tolist()
+    avg_metrics = {k: (v / ntokens).item() for k, v in all_metrics.items()}
+    avg_rewards = (all_rewards / ntokens).tolist()
+    avg_loss = (all_losses / ntokens).item()
+
+    return avg_loss, avg_rewards, ntokens, avg_metrics


 def train_dpo(
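The evaluation path now keeps a token-weighted running sum of each metric and divides by the total token count at the end, i.e. a weighted mean over batches. A minimal single-process sketch of that scheme with made-up batch values (no model, no distributed all_sum):

import mlx.core as mx

batches = [  # (per-batch metrics, token count), made-up values
    ({"accuracies": mx.array(1.0), "margins": mx.array(0.20)}, mx.array(128)),
    ({"accuracies": mx.array(0.0), "margins": mx.array(-0.05)}, mx.array(64)),
]

all_metrics, ntokens = None, mx.array(0)
for metrics, toks in batches:
    ntokens = ntokens + toks
    if all_metrics is None:
        all_metrics = {k: v * toks for k, v in metrics.items()}
    else:
        for k, v in metrics.items():
            all_metrics[k] += v * toks

# Each metric is weighted by how many tokens its batch contributed.
avg_metrics = {k: (v / ntokens).item() for k, v in all_metrics.items()}
print(avg_metrics)  # accuracies == 128 / 192 here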
@@ -246,8 +267,7 @@ def train_dpo(
     def step(batch):
         chosen, rejected, chosen_masks, rejected_masks = batch

-        # Remove loss_type from the call
-        (loss, reward, toks), grad = loss_value_and_grad(
+        (loss, reward, toks, metrics), grad = loss_value_and_grad(
             model,
             reference_model,
             chosen,
@@ -256,15 +276,11 @@ def train_dpo(
             rejected_masks
         )

-        # All reduce the gradients if running in distributed mode
         grad = average_gradients(grad)

-        # Model update
         optimizer.update(model, grad)

-        return loss, reward, toks
+        return loss, reward, toks, metrics

-    # Create a wrapper function that includes all required arguments
     def loss_wrapper(model, ref_model, chosen, rejected, chosen_masks, rejected_masks):
         return loss_fn(
             model=model,
@@ -279,7 +295,6 @@ def train_dpo(
             is_reference_free=args.is_reference_free
         )

-    # Create value_and_grad with the wrapper
     loss_value_and_grad = nn.value_and_grad(model, loss_wrapper)

     losses = 0
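A note on the `(loss, reward, toks, metrics), grad = loss_value_and_grad(...)` unpacking in step(): MLX's value_and_grad differentiates only the first element of a returned tuple and passes the remaining elements through unchanged, so the extra metrics dict does not affect the gradients. A minimal sketch with a hypothetical toy model (nothing below comes from the patch itself):

import mlx.core as mx
import mlx.nn as nn

model = nn.Linear(4, 1)

def loss_with_aux(model, x, y):
    pred = model(x)[:, 0]               # (batch,)
    loss = mx.mean((pred - y) ** 2)     # scalar loss, differentiated
    aux = {"pred_mean": mx.mean(pred)}  # auxiliary output, passed through as-is
    return loss, aux

loss_value_and_grad = nn.value_and_grad(model, loss_with_aux)

x = mx.random.normal((8, 4))
y = mx.random.normal((8,))
(loss, aux), grad = loss_value_and_grad(model, x, y)
print(loss.item(), aux["pred_mean"].item())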
@@ -287,8 +302,15 @@ def train_dpo(
     n_tokens = 0
     steps = 0
     trained_tokens = 0
+    accumulated_metrics = {
+        'accuracies': 0,
+        'margins': 0,
+        'policy_rejected_logps': 0,
+        'policy_chosen_logps': 0,
+        'rejected_logits_mean': 0,
+        'chosen_logits_mean': 0
+    }

-    # Main training loop
     start = time.perf_counter()
     for it, batch in zip(
         range(1, args.iters + 1),
@@ -302,7 +324,7 @@ def train_dpo(
         # Report validation loss if needed
         if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
             stop = time.perf_counter()
-            val_loss, val_rewards = evaluate_dpo(
+            val_loss, val_rewards, val_ntokens, val_metrics = evaluate_dpo(
                 model=model,
                 reference_model=reference_model,
                 dataset=val_dataset,
@@ -322,37 +344,40 @@ def train_dpo(
                     f"Val loss {val_loss:.8f}, "
                     f"Val chosen reward {val_rewards[0]:.3f}, "
                     f"Val rejected reward {val_rewards[1]:.3f}, "
+                    f"Val accuracy {val_metrics['accuracies']:.3f}, "
+                    f"Val margin {val_metrics['margins']:.3f}, "
                     f"Val took {val_time:.3f}s",
                     flush=True,
                 )

             if training_callback is not None:
-                val_info = {
+                training_callback.on_val_loss_report({
                     "iteration": it,
                     "val_loss": val_loss,
                     "val_chosen_reward": val_rewards[0],
                     "val_rejected_reward": val_rewards[1],
+                    **{f"val_{k}": v for k, v in val_metrics.items()},
                     "val_time": val_time,
-                }
-                training_callback.on_val_loss_report(val_info)
+                })

             start = time.perf_counter()

-        loss, reward, toks = step(batch)
+        loss, reward, toks, metrics = step(batch)
         losses += loss
         rewards += reward
         n_tokens += toks
         steps += 1
+        for k, v in metrics.items():
+            accumulated_metrics[k] += v
         mx.eval(state, losses, rewards, n_tokens)

-        # Report training loss if needed
         if it % args.steps_per_report == 0 or it == args.iters:
             stop = time.perf_counter()

-            train_loss = mx.distributed.all_sum(losses).item()
-            train_loss /= steps * world_size
+            train_loss = mx.distributed.all_sum(losses).item() / (steps * world_size)
             train_rewards = mx.distributed.all_sum(rewards).tolist()
             train_rewards = [r / (steps * world_size) for r in train_rewards]
+            avg_metrics = {k: v / (steps * world_size) for k, v in accumulated_metrics.items()}
             n_tokens = mx.distributed.all_sum(n_tokens).item()
             learning_rate = optimizer.learning_rate.item()
             it_sec = args.steps_per_report / (stop - start)
@@ -365,6 +390,8 @@ def train_dpo(
                     f"Iter {it}: Train loss {train_loss:.8f}, "
                     f"Chosen reward {train_rewards[0]:.3f}, "
                     f"Rejected reward {train_rewards[1]:.3f}, "
+                    f"Accuracy {avg_metrics['accuracies']:.3f}, "
+                    f"Margin {avg_metrics['margins']:.3f}, "
                     f"Learning Rate {learning_rate:.3e}, "
                     f"It/sec {it_sec:.3f}, "
                     f"Tokens/sec {tokens_sec:.3f}, "
@@ -379,6 +406,7 @@ def train_dpo(
                     "train_loss": train_loss,
                     "train_chosen_reward": train_rewards[0],
                     "train_rejected_reward": train_rewards[1],
+                    **{f"train_{k}": v for k, v in avg_metrics.items()},
                     "learning_rate": learning_rate,
                     "iterations_per_second": it_sec,
                     "tokens_per_second": tokens_sec,