mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-06-27 03:05:20 +08:00

commit 0ff1289bd9 ("updates"), parent 86b315fdf9
```diff
@@ -176,12 +176,12 @@ def evaluate_dpo(
     num_batches,
     beta: float,
     delta: float,
-    max_seq_length=2048,
-    loss_fn: callable = dpo_loss,
-    loss_type="sigmoid",
+    max_seq_length,
+    loss_type,
+    loss_fn: callable = dpo_loss
 ):
     all_losses = 0
-    all_rewards = mx.zeros((2,))  # [chosen_reward, rejected_reward]
+    all_rewards = mx.zeros((2,))
     ntokens = 0
 
     index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
```
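The reordered signature leaves `loss_fn` as the only parameter with a default. For orientation, here is a hypothetical sketch of a `sigmoid`-type DPO loss matching the call shape used below. The names `dpo_loss_sketch` and `sequence_logps` are illustrative, the `dpop` branch and the `delta` semantics are assumptions, and none of this is the repository's actual `dpo_loss`:

```python
import mlx.core as mx
import mlx.nn as nn


def sequence_logps(model, tokens, masks):
    # Per-sequence sum of token log-probabilities under teacher forcing.
    logits = model(tokens[:, :-1])
    targets = tokens[:, 1:]
    token_logps = -nn.losses.cross_entropy(logits, targets, reduction="none")
    token_logps = token_logps * masks[:, 1:]
    return token_logps.sum(-1), masks[:, 1:].sum(-1)


def dpo_loss_sketch(
    model, reference_teacher_model, chosen, rejected,
    chosen_masks, rejected_masks, beta, delta, loss_type="sigmoid",
):
    pi_c, toks_c = sequence_logps(model, chosen, chosen_masks)
    pi_r, toks_r = sequence_logps(model, rejected, rejected_masks)
    ref_c, _ = sequence_logps(reference_teacher_model, chosen, chosen_masks)
    ref_r, _ = sequence_logps(reference_teacher_model, rejected, rejected_masks)

    # Standard DPO objective: -log sigmoid of the scaled difference of
    # policy/reference log-ratios between chosen and rejected responses.
    margin = (pi_c - ref_c) - (pi_r - ref_r)
    loss = -nn.log_sigmoid(beta * margin)
    if loss_type == "dpop":
        # DPO-Positive style penalty, weighted by delta (assumed semantics).
        loss = loss + delta * mx.maximum(ref_c - pi_c, 0.0)
    loss = loss.mean()

    # Implicit per-pair rewards, as in the DPO paper.
    rewards = mx.stack(
        [(beta * (pi_c - ref_c)).mean(), (beta * (pi_r - ref_r)).mean()]
    )
    ntokens = (toks_c + toks_r).sum()
    return loss, rewards, ntokens
```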
```diff
@@ -195,6 +195,7 @@ def evaluate_dpo(
         ),
     ):
         chosen, rejected, chosen_masks, rejected_masks = batch
+
         loss, reward, toks = loss_fn(
             model=model,
             reference_teacher_model=reference_model,
@@ -206,18 +207,18 @@ def evaluate_dpo(
             beta=beta,
             delta=delta,
         )
 
         all_losses += loss * toks
         all_rewards += reward
         ntokens += toks
         mx.eval(all_losses, all_rewards, ntokens)
 
+    all_losses = mx.distributed.all_sum(all_losses)
+    all_rewards = mx.distributed.all_sum(all_rewards)
     ntokens = mx.distributed.all_sum(ntokens)
 
-    ntokens = mx.distributed.all_sum(ntokens)
-
     return (all_losses / ntokens).item(), all_rewards.tolist()
 
 
 def train_dpo(
     model,
     reference_model,
```
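The last hunk appears to remove a duplicated `ntokens` all-reduce so that each accumulator is reduced exactly once. Below is a self-contained sketch of the resulting aggregation pattern, assuming a hypothetical `local_batches` iterator in place of the real batch loop:

```python
import mlx.core as mx


def local_batches():
    # Hypothetical stand-in for this rank's shard of the eval batches:
    # yields (mean loss, [chosen, rejected] rewards, token count).
    yield mx.array(0.61), mx.array([0.40, -0.23]), mx.array(512)
    yield mx.array(0.58), mx.array([0.37, -0.30]), mx.array(480)


all_losses = mx.array(0.0)
all_rewards = mx.zeros((2,))  # [chosen_reward, rejected_reward]
ntokens = mx.array(0)

for loss, reward, toks in local_batches():
    all_losses += loss * toks  # token-weighted loss sum
    all_rewards += reward
    ntokens += toks
    mx.eval(all_losses, all_rewards, ntokens)  # force evaluation per batch

# Reduce each accumulator exactly once across ranks; with a single
# process these all_sum calls are identity operations.
all_losses = mx.distributed.all_sum(all_losses)
all_rewards = mx.distributed.all_sum(all_rewards)
ntokens = mx.distributed.all_sum(ntokens)

print((all_losses / ntokens).item(), all_rewards.tolist())
```

Summing `loss * toks` locally and dividing by the global token count only after the all-reduce yields a token-weighted mean that is invariant to how batches are sharded across ranks.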