This commit is contained in:
Goekdeniz-Guelmez 2025-01-25 22:03:32 +01:00
parent 86b315fdf9
commit 0ff1289bd9

View File

@ -176,12 +176,12 @@ def evaluate_dpo(
num_batches, num_batches,
beta: float, beta: float,
delta: float, delta: float,
max_seq_length=2048, max_seq_length,
loss_fn: callable = dpo_loss, loss_type,
loss_type="sigmoid", loss_fn: callable = dpo_loss
): ):
all_losses = 0 all_losses = 0
all_rewards = mx.zeros((2,)) # [chosen_reward, rejected_reward] all_rewards = mx.zeros((2,))
ntokens = 0 ntokens = 0
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1) index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
@ -195,6 +195,7 @@ def evaluate_dpo(
), ),
): ):
chosen, rejected, chosen_masks, rejected_masks = batch chosen, rejected, chosen_masks, rejected_masks = batch
loss, reward, toks = loss_fn( loss, reward, toks = loss_fn(
model=model, model=model,
reference_teacher_model=reference_model, reference_teacher_model=reference_model,
@ -206,18 +207,18 @@ def evaluate_dpo(
beta=beta, beta=beta,
delta=delta, delta=delta,
) )
all_losses += loss * toks all_losses += loss * toks
all_rewards += reward all_rewards += reward
ntokens += toks ntokens += toks
mx.eval(all_losses, all_rewards, ntokens)
all_losses = mx.distributed.all_sum(all_losses) all_losses = mx.distributed.all_sum(all_losses)
all_rewards = mx.distributed.all_sum(all_rewards) all_rewards = mx.distributed.all_sum(all_rewards)
ntokens = mx.distributed.all_sum(ntokens) ntokens = mx.distributed.all_sum(ntokens)
return (all_losses / ntokens).item(), all_rewards.tolist() return (all_losses / ntokens).item(), all_rewards.tolist()
def train_dpo( def train_dpo(
model, model,
reference_model, reference_model,