mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-26 02:33:23 +08:00
Removed rejected_rewards handling, Updated batch unpacking to match iterator, Updated batch unpacking to match iterator, Added preference score scaling, Simplified reward calculation, Removed redundant rejected_rewards
This commit is contained in:
parent
09ed837896
commit
d8e7834345
@ -18,52 +18,50 @@ class ORPOTrainingArgs(TrainingArgs):
|
||||
)
|
||||
|
||||
|
||||
def orpo_loss(model, chosen, rejected, chosen_masks, rejected_masks, chosen_rewards, rejected_rewards, beta=0.1):
|
||||
def get_logps(model, x, mask):
|
||||
inputs = x[:, :-1]
|
||||
targets = x[:, 1:]
|
||||
logits = model(inputs)
|
||||
logp = -nn.losses.cross_entropy(logits, targets, reduction='none')
|
||||
seq_lengths = mask[:, :-1].sum(-1)
|
||||
logp_sum = (logp * mask[:, :-1]).sum(-1) / seq_lengths
|
||||
logits_mean = (logits * mask[:, :-1, None]).sum() / mask[:, :-1].sum()
|
||||
return logp_sum, logits_mean
|
||||
def orpo_loss(model, chosen, rejected, chosen_masks, rejected_masks, preference_scores, beta=0.1):
|
||||
def get_logps(model, x, mask):
|
||||
inputs = x[:, :-1]
|
||||
targets = x[:, 1:]
|
||||
logits = model(inputs)
|
||||
logp = -nn.losses.cross_entropy(logits, targets, reduction='none')
|
||||
seq_lengths = mask[:, :-1].sum(-1)
|
||||
logp_sum = (logp * mask[:, :-1]).sum(-1) / seq_lengths
|
||||
logits_mean = (logits * mask[:, :-1, None]).sum() / mask[:, :-1].sum()
|
||||
return logp_sum, logits_mean
|
||||
|
||||
policy_chosen_logps, chosen_logits_mean = get_logps(model, chosen, chosen_masks)
|
||||
policy_rejected_logps, rejected_logits_mean = get_logps(model, rejected, rejected_masks)
|
||||
|
||||
log_odds = (policy_chosen_logps - policy_rejected_logps) - (
|
||||
mx.log1p(-mx.exp(policy_chosen_logps)) - mx.log1p(-mx.exp(policy_rejected_logps))
|
||||
)
|
||||
|
||||
ratio = nn.log_sigmoid(log_odds)
|
||||
loss = -beta * ratio
|
||||
|
||||
accuracies = (log_odds > 0).astype(mx.float32)
|
||||
margins = mx.mean(ratio - 1)
|
||||
metrics = {
|
||||
'accuracies': mx.mean(accuracies),
|
||||
'margins': margins,
|
||||
'policy_rejected_logps': mx.mean(policy_rejected_logps),
|
||||
'policy_chosen_logps': mx.mean(policy_chosen_logps),
|
||||
'rejected_logits_mean': mx.mean(rejected_logits_mean),
|
||||
'chosen_logits_mean': mx.mean(chosen_logits_mean)
|
||||
}
|
||||
|
||||
chosen_reward = beta * policy_chosen_logps
|
||||
rejected_reward = beta * policy_rejected_logps
|
||||
reward = mx.stack([mx.mean(chosen_reward), mx.mean(rejected_reward)])
|
||||
num_tokens = chosen_masks.sum() + rejected_masks.sum()
|
||||
|
||||
return mx.mean(loss), reward, num_tokens, metrics
|
||||
policy_chosen_logps, chosen_logits_mean = get_logps(model, chosen, chosen_masks)
|
||||
policy_rejected_logps, rejected_logits_mean = get_logps(model, rejected, rejected_masks)
|
||||
|
||||
# Apply preference scores
|
||||
policy_chosen_logps = policy_chosen_logps * preference_scores
|
||||
|
||||
log_odds = (policy_chosen_logps - policy_rejected_logps) - (
|
||||
mx.log1p(-mx.exp(policy_chosen_logps)) - mx.log1p(-mx.exp(policy_rejected_logps))
|
||||
)
|
||||
|
||||
ratio = nn.log_sigmoid(log_odds)
|
||||
loss = -beta * ratio
|
||||
|
||||
metrics = {
|
||||
'accuracies': mx.mean((log_odds > 0).astype(mx.float32)),
|
||||
'margins': mx.mean(ratio - 1),
|
||||
'policy_rejected_logps': mx.mean(policy_rejected_logps),
|
||||
'policy_chosen_logps': mx.mean(policy_chosen_logps),
|
||||
'rejected_logits_mean': mx.mean(rejected_logits_mean),
|
||||
'chosen_logits_mean': mx.mean(chosen_logits_mean)
|
||||
}
|
||||
|
||||
chosen_reward = beta * policy_chosen_logps
|
||||
rejected_reward = beta * policy_rejected_logps
|
||||
reward = mx.stack([mx.mean(chosen_reward), mx.mean(rejected_reward)])
|
||||
|
||||
num_tokens = chosen_masks.sum() + rejected_masks.sum()
|
||||
|
||||
return mx.mean(loss), reward, num_tokens, metrics
|
||||
|
||||
|
||||
def iterate_orpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
|
||||
"""
|
||||
Modified batch iterator for ORPO that includes preference scores.
|
||||
Works with pre-tokenized input data.
|
||||
"""
|
||||
# Sort pairs by length of the chosen response
|
||||
def iterate_orpo_batches(dataset, batch_size, max_seq_length, tokenizer, train=False):
|
||||
"""Batch iterator for ORPO with preference scores"""
|
||||
idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]['chosen']))
|
||||
|
||||
if len(dataset) < batch_size:
|
||||
@ -71,70 +69,54 @@ def iterate_orpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=F
|
||||
f"Dataset must have at least batch_size={batch_size}"
|
||||
f" examples but only has {len(dataset)}."
|
||||
)
|
||||
|
||||
|
||||
step = mx.distributed.init().size()
|
||||
if batch_size % step != 0:
|
||||
raise ValueError("The batch size must be divisible by the number of workers")
|
||||
|
||||
batch_idx = [
|
||||
idx[i : i + batch_size : step]
|
||||
for i in range(0, len(idx) - batch_size + 1, batch_size)
|
||||
]
|
||||
|
||||
raise ValueError("Batch size must be divisible by number of workers")
|
||||
|
||||
batch_idx = [idx[i:i + batch_size:step] for i in range(0, len(idx) - batch_size + 1, batch_size)]
|
||||
|
||||
while True:
|
||||
indices = np.random.permutation(len(batch_idx)) if train else range(len(batch_idx))
|
||||
for i in indices:
|
||||
batch = [dataset[j] for j in batch_idx[i]]
|
||||
|
||||
# Get lengths assuming data is already tokenized
|
||||
chosen_lengths = [len(x['chosen']) for x in batch]
|
||||
rejected_lengths = [len(x['rejected']) for x in batch]
|
||||
max_length = max(max(chosen_lengths), max(rejected_lengths))
|
||||
|
||||
if max_length > max_seq_length:
|
||||
print(
|
||||
f"[WARNING] Sequences longer than {max_seq_length} tokens "
|
||||
f"will be truncated."
|
||||
)
|
||||
|
||||
max_length = min(max(max(chosen_lengths), max(rejected_lengths)), max_seq_length)
|
||||
pad_to = 8
|
||||
max_length_in_batch = pad_to * ((max_length + pad_to - 1) // pad_to)
|
||||
max_length_in_batch = min(max_length_in_batch, max_seq_length)
|
||||
|
||||
chosen_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
|
||||
rejected_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
|
||||
chosen_masks = np.zeros((batch_size // step, max_length_in_batch), np.float32)
|
||||
rejected_masks = np.zeros((batch_size // step, max_length_in_batch), np.float32)
|
||||
|
||||
# Get preference scores and convert to rewards
|
||||
preference_scores = [x.get('preference_score', 1.0) for x in batch]
|
||||
chosen_rewards = np.array(preference_scores, np.float32)
|
||||
rejected_rewards = np.array([1.0 - score for score in preference_scores], np.float32)
|
||||
|
||||
for j in range(batch_size // step):
|
||||
# Use pre-tokenized sequences directly
|
||||
chosen_length = min(chosen_lengths[j], max_seq_length)
|
||||
batch_size_per_device = batch_size // step
|
||||
chosen_arr = np.zeros((batch_size_per_device, max_length_in_batch), np.int32)
|
||||
rejected_arr = np.zeros((batch_size_per_device, max_length_in_batch), np.int32)
|
||||
chosen_masks = np.zeros((batch_size_per_device, max_length_in_batch), np.float32)
|
||||
rejected_masks = np.zeros((batch_size_per_device, max_length_in_batch), np.float32)
|
||||
|
||||
preference_scores = np.array([x.get('preference_score', 1.0) for x in batch], np.float32)
|
||||
|
||||
for j in range(batch_size_per_device):
|
||||
chosen_length = min(chosen_lengths[j], max_length_in_batch)
|
||||
rejected_length = min(rejected_lengths[j], max_length_in_batch)
|
||||
|
||||
chosen_arr[j, :chosen_length] = batch[j]['chosen'][:chosen_length]
|
||||
chosen_masks[j, :chosen_length] = 1.0
|
||||
|
||||
rejected_length = min(rejected_lengths[j], max_seq_length)
|
||||
rejected_arr[j, :rejected_length] = batch[j]['rejected'][:rejected_length]
|
||||
rejected_masks[j, :rejected_length] = 1.0
|
||||
|
||||
|
||||
yield (
|
||||
mx.array(chosen_arr),
|
||||
mx.array(rejected_arr),
|
||||
mx.array(chosen_masks),
|
||||
mx.array(rejected_masks),
|
||||
mx.array(chosen_rewards),
|
||||
mx.array(rejected_rewards)
|
||||
mx.array(preference_scores)
|
||||
)
|
||||
|
||||
|
||||
if not train:
|
||||
break
|
||||
|
||||
|
||||
def evaluate_orpo(model, dataset, tokenizer, batch_size, num_batches, beta: float, max_seq_length=2048):
|
||||
def evaluate_orpo(model, dataset, batch_size, num_batches, beta: float, max_seq_length=2048):
|
||||
all_losses = 0
|
||||
all_rewards = mx.zeros((2,))
|
||||
all_metrics = None
|
||||
@ -145,20 +127,18 @@ def evaluate_orpo(model, dataset, tokenizer, batch_size, num_batches, beta: floa
|
||||
index_iterator,
|
||||
iterate_orpo_batches(
|
||||
dataset=dataset,
|
||||
tokenizer=tokenizer,
|
||||
batch_size=batch_size,
|
||||
max_seq_length=max_seq_length,
|
||||
),
|
||||
):
|
||||
chosen, rejected, chosen_masks, rejected_masks, chosen_rewards, rejected_rewards = batch
|
||||
chosen, rejected, chosen_masks, rejected_masks, preference_scores = batch
|
||||
loss, reward, toks, metrics = orpo_loss(
|
||||
model=model,
|
||||
chosen=chosen,
|
||||
rejected=rejected,
|
||||
chosen_masks=chosen_masks,
|
||||
rejected_masks=rejected_masks,
|
||||
chosen_rewards=chosen_rewards,
|
||||
rejected_rewards=rejected_rewards,
|
||||
preference_scores=preference_scores,
|
||||
beta=beta
|
||||
)
|
||||
all_losses += loss * toks
|
||||
@ -207,7 +187,7 @@ def train_orpo(
|
||||
state = [model.state, optimizer.state]
|
||||
|
||||
def step(batch):
|
||||
chosen, rejected, chosen_masks, rejected_masks, chosen_rewards, rejected_rewards = batch
|
||||
chosen, rejected, chosen_masks, rejected_masks, preference_scores = batch
|
||||
|
||||
(loss, reward, toks, metrics), grad = loss_value_and_grad(
|
||||
model,
|
||||
@ -215,8 +195,7 @@ def train_orpo(
|
||||
rejected,
|
||||
chosen_masks,
|
||||
rejected_masks,
|
||||
chosen_rewards,
|
||||
rejected_rewards
|
||||
preference_scores=preference_scores,
|
||||
)
|
||||
|
||||
grad = average_gradients(grad)
|
||||
@ -224,16 +203,14 @@ def train_orpo(
|
||||
|
||||
return loss, reward, toks, metrics
|
||||
|
||||
def loss_wrapper(model, chosen, rejected, chosen_masks, rejected_masks,
|
||||
chosen_rewards, rejected_rewards):
|
||||
def loss_wrapper(model, chosen, rejected, chosen_masks, rejected_masks, preference_scores):
|
||||
return orpo_loss(
|
||||
model=model,
|
||||
chosen=chosen,
|
||||
rejected=rejected,
|
||||
chosen_masks=chosen_masks,
|
||||
rejected_masks=rejected_masks,
|
||||
chosen_rewards=chosen_rewards,
|
||||
rejected_rewards=rejected_rewards,
|
||||
preference_scores=preference_scores,
|
||||
beta=args.beta
|
||||
)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user