Removed redundant rejected_rewards handling, updated batch unpacking to match the iterator, added preference score scaling, and simplified the reward calculation

Goekdeniz-Guelmez 2025-01-25 21:35:37 +01:00
parent 09ed837896
commit d8e7834345


@@ -18,52 +18,50 @@ class ORPOTrainingArgs(TrainingArgs):
    )
def orpo_loss(model, chosen, rejected, chosen_masks, rejected_masks, chosen_rewards, rejected_rewards, beta=0.1):
    def get_logps(model, x, mask):
        inputs = x[:, :-1]
        targets = x[:, 1:]
        logits = model(inputs)
        logp = -nn.losses.cross_entropy(logits, targets, reduction='none')
        seq_lengths = mask[:, :-1].sum(-1)
        logp_sum = (logp * mask[:, :-1]).sum(-1) / seq_lengths
        logits_mean = (logits * mask[:, :-1, None]).sum() / mask[:, :-1].sum()
        return logp_sum, logits_mean
def orpo_loss(model, chosen, rejected, chosen_masks, rejected_masks, preference_scores, beta=0.1):
    def get_logps(model, x, mask):
        inputs = x[:, :-1]
        targets = x[:, 1:]
        logits = model(inputs)
        logp = -nn.losses.cross_entropy(logits, targets, reduction='none')
        seq_lengths = mask[:, :-1].sum(-1)
        logp_sum = (logp * mask[:, :-1]).sum(-1) / seq_lengths
        logits_mean = (logits * mask[:, :-1, None]).sum() / mask[:, :-1].sum()
        return logp_sum, logits_mean
    policy_chosen_logps, chosen_logits_mean = get_logps(model, chosen, chosen_masks)
    policy_rejected_logps, rejected_logits_mean = get_logps(model, rejected, rejected_masks)
    log_odds = (policy_chosen_logps - policy_rejected_logps) - (
        mx.log1p(-mx.exp(policy_chosen_logps)) - mx.log1p(-mx.exp(policy_rejected_logps))
    )
    ratio = nn.log_sigmoid(log_odds)
    loss = -beta * ratio
    accuracies = (log_odds > 0).astype(mx.float32)
    margins = mx.mean(ratio - 1)
    metrics = {
        'accuracies': mx.mean(accuracies),
        'margins': margins,
        'policy_rejected_logps': mx.mean(policy_rejected_logps),
        'policy_chosen_logps': mx.mean(policy_chosen_logps),
        'rejected_logits_mean': mx.mean(rejected_logits_mean),
        'chosen_logits_mean': mx.mean(chosen_logits_mean)
    }
    chosen_reward = beta * policy_chosen_logps
    rejected_reward = beta * policy_rejected_logps
    reward = mx.stack([mx.mean(chosen_reward), mx.mean(rejected_reward)])
    num_tokens = chosen_masks.sum() + rejected_masks.sum()
    return mx.mean(loss), reward, num_tokens, metrics
    policy_chosen_logps, chosen_logits_mean = get_logps(model, chosen, chosen_masks)
    policy_rejected_logps, rejected_logits_mean = get_logps(model, rejected, rejected_masks)
    # Apply preference scores
    policy_chosen_logps = policy_chosen_logps * preference_scores
    log_odds = (policy_chosen_logps - policy_rejected_logps) - (
        mx.log1p(-mx.exp(policy_chosen_logps)) - mx.log1p(-mx.exp(policy_rejected_logps))
    )
    ratio = nn.log_sigmoid(log_odds)
    loss = -beta * ratio
    metrics = {
        'accuracies': mx.mean((log_odds > 0).astype(mx.float32)),
        'margins': mx.mean(ratio - 1),
        'policy_rejected_logps': mx.mean(policy_rejected_logps),
        'policy_chosen_logps': mx.mean(policy_chosen_logps),
        'rejected_logits_mean': mx.mean(rejected_logits_mean),
        'chosen_logits_mean': mx.mean(chosen_logits_mean)
    }
    chosen_reward = beta * policy_chosen_logps
    rejected_reward = beta * policy_rejected_logps
    reward = mx.stack([mx.mean(chosen_reward), mx.mean(rejected_reward)])
    num_tokens = chosen_masks.sum() + rejected_masks.sum()
    return mx.mean(loss), reward, num_tokens, metrics
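To make the new signature concrete, here is a minimal sketch of calling the updated orpo_loss on a toy batch. The TinyLM module, the shapes, and the score values are invented for illustration; any mlx.nn model that maps token ids to logits would work the same way.

import mlx.core as mx
import mlx.nn as nn

class TinyLM(nn.Module):
    # Hypothetical stand-in for the policy model: token ids -> logits.
    def __init__(self, vocab_size=32, dims=16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dims)
        self.out = nn.Linear(dims, vocab_size)

    def __call__(self, inputs):
        return self.out(self.embed(inputs))

model = TinyLM()
chosen = mx.random.randint(0, 32, (2, 8))       # [batch, seq] token ids
rejected = mx.random.randint(0, 32, (2, 8))
chosen_masks = mx.ones((2, 8))
rejected_masks = mx.ones((2, 8))
preference_scores = mx.array([1.0, 0.8])        # per-pair confidence in the chosen answer

loss, reward, num_tokens, metrics = orpo_loss(
    model, chosen, rejected, chosen_masks, rejected_masks,
    preference_scores, beta=0.1,
)
print(loss.item(), metrics['accuracies'].item())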
def iterate_orpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
    """
    Modified batch iterator for ORPO that includes preference scores.
    Works with pre-tokenized input data.
    """
    # Sort pairs by length of the chosen response
def iterate_orpo_batches(dataset, batch_size, max_seq_length, tokenizer, train=False):
    """Batch iterator for ORPO with preference scores"""
    idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]['chosen']))
    if len(dataset) < batch_size:
@@ -71,70 +69,54 @@ def iterate_orpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=F
            f"Dataset must have at least batch_size={batch_size}"
            f" examples but only has {len(dataset)}."
        )
    step = mx.distributed.init().size()
    if batch_size % step != 0:
        raise ValueError("The batch size must be divisible by the number of workers")
    batch_idx = [
        idx[i : i + batch_size : step]
        for i in range(0, len(idx) - batch_size + 1, batch_size)
    ]
        raise ValueError("Batch size must be divisible by number of workers")
    batch_idx = [idx[i:i + batch_size:step] for i in range(0, len(idx) - batch_size + 1, batch_size)]
    while True:
        indices = np.random.permutation(len(batch_idx)) if train else range(len(batch_idx))
        for i in indices:
            batch = [dataset[j] for j in batch_idx[i]]
            # Get lengths assuming data is already tokenized
            chosen_lengths = [len(x['chosen']) for x in batch]
            rejected_lengths = [len(x['rejected']) for x in batch]
            max_length = max(max(chosen_lengths), max(rejected_lengths))
            if max_length > max_seq_length:
                print(
                    f"[WARNING] Sequences longer than {max_seq_length} tokens "
                    f"will be truncated."
                )
            max_length = min(max(max(chosen_lengths), max(rejected_lengths)), max_seq_length)
            pad_to = 8
            max_length_in_batch = pad_to * ((max_length + pad_to - 1) // pad_to)
            max_length_in_batch = min(max_length_in_batch, max_seq_length)
            chosen_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
            rejected_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
            chosen_masks = np.zeros((batch_size // step, max_length_in_batch), np.float32)
            rejected_masks = np.zeros((batch_size // step, max_length_in_batch), np.float32)
            # Get preference scores and convert to rewards
            preference_scores = [x.get('preference_score', 1.0) for x in batch]
            chosen_rewards = np.array(preference_scores, np.float32)
            rejected_rewards = np.array([1.0 - score for score in preference_scores], np.float32)
            for j in range(batch_size // step):
                # Use pre-tokenized sequences directly
                chosen_length = min(chosen_lengths[j], max_seq_length)
            batch_size_per_device = batch_size // step
            chosen_arr = np.zeros((batch_size_per_device, max_length_in_batch), np.int32)
            rejected_arr = np.zeros((batch_size_per_device, max_length_in_batch), np.int32)
            chosen_masks = np.zeros((batch_size_per_device, max_length_in_batch), np.float32)
            rejected_masks = np.zeros((batch_size_per_device, max_length_in_batch), np.float32)
            preference_scores = np.array([x.get('preference_score', 1.0) for x in batch], np.float32)
            for j in range(batch_size_per_device):
                chosen_length = min(chosen_lengths[j], max_length_in_batch)
                rejected_length = min(rejected_lengths[j], max_length_in_batch)
                chosen_arr[j, :chosen_length] = batch[j]['chosen'][:chosen_length]
                chosen_masks[j, :chosen_length] = 1.0
                rejected_length = min(rejected_lengths[j], max_seq_length)
                rejected_arr[j, :rejected_length] = batch[j]['rejected'][:rejected_length]
                rejected_masks[j, :rejected_length] = 1.0
            yield (
                mx.array(chosen_arr),
                mx.array(rejected_arr),
                mx.array(chosen_masks),
                mx.array(rejected_masks),
                mx.array(chosen_rewards),
                mx.array(rejected_rewards)
                mx.array(preference_scores)
            )
        if not train:
            break
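As a usage sketch, the iterator can be driven by a tiny hand-written, pre-tokenized dataset like the one below; the token ids and scores are invented for illustration, and the tokenizer argument goes unused because the data is already tokenized.

# Hypothetical pre-tokenized pairs; 'preference_score' is optional and defaults to 1.0.
dataset = [
    {"chosen": [1, 2, 3, 4], "rejected": [1, 2, 9], "preference_score": 0.9},
    {"chosen": [5, 6, 7], "rejected": [5, 8]},
]

batches = iterate_orpo_batches(dataset, batch_size=2, max_seq_length=32, tokenizer=None)
chosen, rejected, chosen_masks, rejected_masks, preference_scores = next(batches)
# Sequences are padded to a multiple of 8 tokens, so chosen.shape is (2, 8) here,
# and preference_scores holds one score per pair in sorted-by-length order.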
def evaluate_orpo(model, dataset, tokenizer, batch_size, num_batches, beta: float, max_seq_length=2048):
def evaluate_orpo(model, dataset, batch_size, num_batches, beta: float, max_seq_length=2048):
    all_losses = 0
    all_rewards = mx.zeros((2,))
    all_metrics = None
@@ -145,20 +127,18 @@ def evaluate_orpo(model, dataset, tokenizer, batch_size, num_batches, beta: floa
        index_iterator,
        iterate_orpo_batches(
            dataset=dataset,
            tokenizer=tokenizer,
            batch_size=batch_size,
            max_seq_length=max_seq_length,
        ),
    ):
        chosen, rejected, chosen_masks, rejected_masks, chosen_rewards, rejected_rewards = batch
        chosen, rejected, chosen_masks, rejected_masks, preference_scores = batch
        loss, reward, toks, metrics = orpo_loss(
            model=model,
            chosen=chosen,
            rejected=rejected,
            chosen_masks=chosen_masks,
            rejected_masks=rejected_masks,
            chosen_rewards=chosen_rewards,
            rejected_rewards=rejected_rewards,
            preference_scores=preference_scores,
            beta=beta
        )
        all_losses += loss * toks
@@ -207,7 +187,7 @@ def train_orpo(
    state = [model.state, optimizer.state]
    def step(batch):
        chosen, rejected, chosen_masks, rejected_masks, chosen_rewards, rejected_rewards = batch
        chosen, rejected, chosen_masks, rejected_masks, preference_scores = batch
        (loss, reward, toks, metrics), grad = loss_value_and_grad(
            model,
@@ -215,8 +195,7 @@ def train_orpo(
            rejected,
            chosen_masks,
            rejected_masks,
            chosen_rewards,
            rejected_rewards
            preference_scores=preference_scores,
        )
        grad = average_gradients(grad)
@@ -224,16 +203,14 @@
        return loss, reward, toks, metrics
    def loss_wrapper(model, chosen, rejected, chosen_masks, rejected_masks,
                     chosen_rewards, rejected_rewards):
    def loss_wrapper(model, chosen, rejected, chosen_masks, rejected_masks, preference_scores):
        return orpo_loss(
            model=model,
            chosen=chosen,
            rejected=rejected,
            chosen_masks=chosen_masks,
            rejected_masks=rejected_masks,
            chosen_rewards=chosen_rewards,
            rejected_rewards=rejected_rewards,
            preference_scores=preference_scores,
            beta=args.beta
        )
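For context, loss_wrapper is the function that gets differentiated inside step(). The wiring below sits outside the hunks shown in this diff, so treat it as a minimal sketch of the standard MLX pattern rather than the file's literal code.

# Assumed construction of loss_value_and_grad used in step() above:
# nn.value_and_grad differentiates loss_wrapper's first return value (the
# scalar loss) with respect to the model's trainable parameters and returns
# ((loss, reward, toks, metrics), grads) when called.
loss_value_and_grad = nn.value_and_grad(model, loss_wrapper)

# step() then averages gradients across data-parallel workers and applies them:
#   grad = average_gradients(grad)
#   optimizer.update(model, grad)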