Author: Goekdeniz-Guelmez
Date: 2025-01-25 22:03:32 +01:00
parent 86b315fdf9
commit 0ff1289bd9


@@ -176,12 +176,12 @@ def evaluate_dpo(
     num_batches,
     beta: float,
     delta: float,
-    max_seq_length=2048,
-    loss_fn: callable = dpo_loss,
-    loss_type="sigmoid",
+    max_seq_length,
+    loss_type,
+    loss_fn: callable = dpo_loss
 ):
     all_losses = 0
-    all_rewards = mx.zeros((2,))  # [chosen_reward, rejected_reward]
+    all_rewards = mx.zeros((2,))
     ntokens = 0
     index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
@@ -195,6 +195,7 @@ def evaluate_dpo(
         ),
     ):
         chosen, rejected, chosen_masks, rejected_masks = batch
         loss, reward, toks = loss_fn(
             model=model,
             reference_teacher_model=reference_model,
@@ -206,18 +207,18 @@ def evaluate_dpo(
             beta=beta,
             delta=delta,
         )
         all_losses += loss * toks
         all_rewards += reward
         ntokens += toks
         mx.eval(all_losses, all_rewards, ntokens)
 
     all_losses = mx.distributed.all_sum(all_losses)
     all_rewards = mx.distributed.all_sum(all_rewards)
     ntokens = mx.distributed.all_sum(ntokens)
 
     return (all_losses / ntokens).item(), all_rewards.tolist()
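For reference, loss_fn defaults to dpo_loss and is expected to return (loss, reward, toks): a scalar loss, a [chosen_reward, rejected_reward] pair, and a token count, which is what the accumulation above relies on (loss is weighted by toks and later normalized by the all-reduced ntokens). The sketch below is a minimal, hypothetical sigmoid-style DPO loss with that signature; it is not the repository's dpo_loss, and the model call convention, the implicit-reward definition, and ignoring delta for the plain sigmoid variant are assumptions.

import mlx.core as mx
import mlx.nn as nn


def dpo_loss_sketch(
    model,
    reference_teacher_model,
    chosen,
    rejected,
    chosen_masks,
    rejected_masks,
    beta,
    delta,  # unused in this plain sigmoid sketch; other loss_type variants may use it (assumption)
):
    def seq_logps(m, tokens, masks):
        # Per-sequence sum of target-token log-probabilities, scored teacher-forced.
        logits = m(tokens[:, :-1]).astype(mx.float32)
        logps = -nn.losses.cross_entropy(logits, tokens[:, 1:], reduction="none")
        return (logps * masks[:, 1:]).sum(-1)

    policy_chosen = seq_logps(model, chosen, chosen_masks)
    policy_rejected = seq_logps(model, rejected, rejected_masks)
    ref_chosen = mx.stop_gradient(seq_logps(reference_teacher_model, chosen, chosen_masks))
    ref_rejected = mx.stop_gradient(seq_logps(reference_teacher_model, rejected, rejected_masks))

    # Sigmoid DPO objective: -log sigmoid(beta * (Δ_policy - Δ_reference)),
    # written as logaddexp(0, -x) for numerical stability.
    margin = (policy_chosen - policy_rejected) - (ref_chosen - ref_rejected)
    losses = mx.logaddexp(0.0, -beta * margin)

    # Implicit rewards averaged over the batch: [chosen_reward, rejected_reward].
    reward = mx.stack(
        [
            (beta * (policy_chosen - ref_chosen)).mean(),
            (beta * (policy_rejected - ref_rejected)).mean(),
        ]
    )

    toks = chosen_masks.sum() + rejected_masks.sum()
    return losses.mean(), reward, toks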

def train_dpo(
    model,
    reference_model,