Mirror of https://github.com/ml-explore/mlx-examples.git

commit 0ff1289bd9
parent 86b315fdf9

    updates
@@ -176,12 +176,12 @@ def evaluate_dpo(
     num_batches,
     beta: float,
     delta: float,
-    max_seq_length=2048,
-    loss_fn: callable = dpo_loss,
-    loss_type="sigmoid",
+    max_seq_length,
+    loss_type,
+    loss_fn: callable = dpo_loss
 ):
     all_losses = 0
-    all_rewards = mx.zeros((2,))  # [chosen_reward, rejected_reward]
+    all_rewards = mx.zeros((2,))
     ntokens = 0
 
     index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
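Note on this hunk: `max_seq_length` and `loss_type` lose their defaults (`2048` and `"sigmoid"`) and `loss_fn` moves to the end of the signature, so call sites must now pass both explicitly. A minimal sketch of an updated call; every parameter name not visible in this hunk is a hypothetical placeholder, not the repository's actual signature:

    # Hypothetical call site -- only num_batches, beta, delta,
    # max_seq_length, loss_type and loss_fn appear in this diff.
    val_loss, val_rewards = evaluate_dpo(
        model=model,
        reference_model=reference_model,
        dataset=val_set,          # assumed parameter name
        batch_size=4,             # assumed parameter name
        num_batches=25,
        beta=0.1,
        delta=50.0,
        max_seq_length=2048,      # no longer defaulted
        loss_type="sigmoid",      # no longer defaulted
    )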
@@ -195,6 +195,7 @@ def evaluate_dpo(
         ),
     ):
         chosen, rejected, chosen_masks, rejected_masks = batch
+
         loss, reward, toks = loss_fn(
             model=model,
             reference_teacher_model=reference_model,
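For context, `dpo_loss` itself is outside this diff. Under the standard DPO objective (Rafailov et al., 2023), a "sigmoid" loss over per-sequence policy and reference log-probabilities reduces to the sketch below; this is a generic illustration under that assumption, not the repository's implementation:

    import mlx.core as mx

    def sigmoid_dpo_loss(policy_chosen, policy_rejected,
                         ref_chosen, ref_rejected, beta=0.1):
        # Per-sequence log-prob ratios of the policy vs. the frozen reference.
        chosen_ratio = policy_chosen - ref_chosen
        rejected_ratio = policy_rejected - ref_rejected
        logits = beta * (chosen_ratio - rejected_ratio)
        # -log(sigmoid(x)) == logaddexp(0, -x), numerically stable.
        loss = mx.logaddexp(0.0, -logits).mean()
        # [chosen_reward, rejected_reward], matching the mx.zeros((2,))
        # accumulator in the previous hunk.
        rewards = mx.stack([beta * chosen_ratio.mean(),
                            beta * rejected_ratio.mean()])
        return loss, rewards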
@@ -206,18 +207,18 @@ def evaluate_dpo(
             beta=beta,
             delta=delta,
         )
+
         all_losses += loss * toks
         all_rewards += reward
         ntokens += toks
-        mx.eval(all_losses, all_rewards, ntokens)
 
     all_losses = mx.distributed.all_sum(all_losses)
     all_rewards = mx.distributed.all_sum(all_rewards)
+
     ntokens = mx.distributed.all_sum(ntokens)
-
     return (all_losses / ntokens).item(), all_rewards.tolist()
 
 
 def train_dpo(
     model,
     reference_model,
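Note on the dropped `mx.eval(all_losses, all_rewards, ntokens)`: MLX arrays are lazy, so that call forced each batch's pending graph to be computed inside the loop; without it, evaluation is deferred until `.item()` and `.tolist()` at the end (the `mx.distributed.all_sum` results are still materialized there). A minimal sketch of what a per-iteration `mx.eval` does:

    import mlx.core as mx

    acc = mx.zeros((1,))
    for _ in range(100):
        acc = acc + mx.ones((1,))
        mx.eval(acc)  # materialize now, keeping the pending graph one step deep
    # Without the mx.eval above, all 100 additions would run at this .item().
    print(acc.item())  # 100.0

Whether one deferred evaluation over the whole eval set is cheaper than per-batch evaluation depends on graph size and memory pressure; that trade-off is not visible in this diff.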