Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-26 18:51:18 +08:00
Removed rejected_rewards handling, updated batch unpacking to match the iterator, added preference score scaling, simplified the reward calculation, and removed the redundant rejected_rewards
parent 09ed837896
commit d8e7834345
@@ -18,52 +18,50 @@ class ORPOTrainingArgs(TrainingArgs):
     )
 
 
-def orpo_loss(model, chosen, rejected, chosen_masks, rejected_masks, chosen_rewards, rejected_rewards, beta=0.1):
+def orpo_loss(model, chosen, rejected, chosen_masks, rejected_masks, preference_scores, beta=0.1):
     def get_logps(model, x, mask):
         inputs = x[:, :-1]
         targets = x[:, 1:]
         logits = model(inputs)
         logp = -nn.losses.cross_entropy(logits, targets, reduction='none')
         seq_lengths = mask[:, :-1].sum(-1)
         logp_sum = (logp * mask[:, :-1]).sum(-1) / seq_lengths
         logits_mean = (logits * mask[:, :-1, None]).sum() / mask[:, :-1].sum()
         return logp_sum, logits_mean
 
     policy_chosen_logps, chosen_logits_mean = get_logps(model, chosen, chosen_masks)
     policy_rejected_logps, rejected_logits_mean = get_logps(model, rejected, rejected_masks)
 
-    log_odds = (policy_chosen_logps - policy_rejected_logps) - (
-        mx.log1p(-mx.exp(policy_chosen_logps)) - mx.log1p(-mx.exp(policy_rejected_logps))
-    )
-
-    ratio = nn.log_sigmoid(log_odds)
-    loss = -beta * ratio
-
-    accuracies = (log_odds > 0).astype(mx.float32)
-    margins = mx.mean(ratio - 1)
-    metrics = {
-        'accuracies': mx.mean(accuracies),
-        'margins': margins,
+    # Apply preference scores
+    policy_chosen_logps = policy_chosen_logps * preference_scores
+
+    log_odds = (policy_chosen_logps - policy_rejected_logps) - (
+        mx.log1p(-mx.exp(policy_chosen_logps)) - mx.log1p(-mx.exp(policy_rejected_logps))
+    )
+
+    ratio = nn.log_sigmoid(log_odds)
+    loss = -beta * ratio
+
+    metrics = {
+        'accuracies': mx.mean((log_odds > 0).astype(mx.float32)),
+        'margins': mx.mean(ratio - 1),
         'policy_rejected_logps': mx.mean(policy_rejected_logps),
         'policy_chosen_logps': mx.mean(policy_chosen_logps),
         'rejected_logits_mean': mx.mean(rejected_logits_mean),
         'chosen_logits_mean': mx.mean(chosen_logits_mean)
     }
 
     chosen_reward = beta * policy_chosen_logps
     rejected_reward = beta * policy_rejected_logps
     reward = mx.stack([mx.mean(chosen_reward), mx.mean(rejected_reward)])
     num_tokens = chosen_masks.sum() + rejected_masks.sum()
 
     return mx.mean(loss), reward, num_tokens, metrics
 
 
-def iterate_orpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
-    """
-    Modified batch iterator for ORPO that includes preference scores.
-    Works with pre-tokenized input data.
-    """
-    # Sort pairs by length of the chosen response
+def iterate_orpo_batches(dataset, batch_size, max_seq_length, tokenizer, train=False):
+    """Batch iterator for ORPO with preference scores"""
    idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]['chosen']))
 
     if len(dataset) < batch_size:
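As a reading aid, here is a minimal, self-contained sketch of the odds-ratio computation above with the new preference_scores argument applied to toy per-sequence average log-probabilities. The numeric values and the standalone-script form are illustrative assumptions, not part of the commit; the operations themselves mirror the new orpo_loss body.

# Toy sketch of the updated orpo_loss math; the values are made up.
import mlx.core as mx
import mlx.nn as nn

beta = 0.1
policy_chosen_logps = mx.array([-0.8, -1.2])      # per-pair average log-probs
policy_rejected_logps = mx.array([-1.5, -1.1])
preference_scores = mx.array([1.0, 0.6])          # defaults to 1.0 when absent

# The commit scales the chosen log-probs by the preference score
policy_chosen_logps = policy_chosen_logps * preference_scores

log_odds = (policy_chosen_logps - policy_rejected_logps) - (
    mx.log1p(-mx.exp(policy_chosen_logps)) - mx.log1p(-mx.exp(policy_rejected_logps))
)
ratio = nn.log_sigmoid(log_odds)
loss = -beta * ratio
print(mx.mean(loss).item(), mx.mean((log_odds > 0).astype(mx.float32)).item())

With preference_scores left at their default of 1.0 the scaling is a no-op and the expression reduces to the previous ORPO loss.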
@@ -71,70 +69,54 @@ def iterate_orpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
             f"Dataset must have at least batch_size={batch_size}"
             f" examples but only has {len(dataset)}."
         )
 
     step = mx.distributed.init().size()
     if batch_size % step != 0:
-        raise ValueError("The batch size must be divisible by the number of workers")
+        raise ValueError("Batch size must be divisible by number of workers")
 
-    batch_idx = [
-        idx[i : i + batch_size : step]
-        for i in range(0, len(idx) - batch_size + 1, batch_size)
-    ]
+    batch_idx = [idx[i:i + batch_size:step] for i in range(0, len(idx) - batch_size + 1, batch_size)]
 
     while True:
         indices = np.random.permutation(len(batch_idx)) if train else range(len(batch_idx))
         for i in indices:
             batch = [dataset[j] for j in batch_idx[i]]
 
-            # Get lengths assuming data is already tokenized
             chosen_lengths = [len(x['chosen']) for x in batch]
             rejected_lengths = [len(x['rejected']) for x in batch]
-            max_length = max(max(chosen_lengths), max(rejected_lengths))
-
-            if max_length > max_seq_length:
-                print(
-                    f"[WARNING] Sequences longer than {max_seq_length} tokens "
-                    f"will be truncated."
-                )
+            max_length = min(max(max(chosen_lengths), max(rejected_lengths)), max_seq_length)
 
             pad_to = 8
             max_length_in_batch = pad_to * ((max_length + pad_to - 1) // pad_to)
-            max_length_in_batch = min(max_length_in_batch, max_seq_length)
 
-            chosen_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
-            rejected_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
-            chosen_masks = np.zeros((batch_size // step, max_length_in_batch), np.float32)
-            rejected_masks = np.zeros((batch_size // step, max_length_in_batch), np.float32)
-
-            # Get preference scores and convert to rewards
-            preference_scores = [x.get('preference_score', 1.0) for x in batch]
-            chosen_rewards = np.array(preference_scores, np.float32)
-            rejected_rewards = np.array([1.0 - score for score in preference_scores], np.float32)
-
-            for j in range(batch_size // step):
-                # Use pre-tokenized sequences directly
-                chosen_length = min(chosen_lengths[j], max_seq_length)
+            batch_size_per_device = batch_size // step
+            chosen_arr = np.zeros((batch_size_per_device, max_length_in_batch), np.int32)
+            rejected_arr = np.zeros((batch_size_per_device, max_length_in_batch), np.int32)
+            chosen_masks = np.zeros((batch_size_per_device, max_length_in_batch), np.float32)
+            rejected_masks = np.zeros((batch_size_per_device, max_length_in_batch), np.float32)
+
+            preference_scores = np.array([x.get('preference_score', 1.0) for x in batch], np.float32)
+
+            for j in range(batch_size_per_device):
+                chosen_length = min(chosen_lengths[j], max_length_in_batch)
+                rejected_length = min(rejected_lengths[j], max_length_in_batch)
+
                 chosen_arr[j, :chosen_length] = batch[j]['chosen'][:chosen_length]
                 chosen_masks[j, :chosen_length] = 1.0
-
-                rejected_length = min(rejected_lengths[j], max_seq_length)
                 rejected_arr[j, :rejected_length] = batch[j]['rejected'][:rejected_length]
                 rejected_masks[j, :rejected_length] = 1.0
 
             yield (
                 mx.array(chosen_arr),
                 mx.array(rejected_arr),
                 mx.array(chosen_masks),
                 mx.array(rejected_masks),
-                mx.array(chosen_rewards),
-                mx.array(rejected_rewards)
+                mx.array(preference_scores)
             )
 
         if not train:
             break
 
 
-def evaluate_orpo(model, dataset, tokenizer, batch_size, num_batches, beta: float, max_seq_length=2048):
+def evaluate_orpo(model, dataset, batch_size, num_batches, beta: float, max_seq_length=2048):
     all_losses = 0
     all_rewards = mx.zeros((2,))
     all_metrics = None
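For orientation, a small self-contained sketch of the padding and preference-score handling the new iterator applies to a single batch of pre-tokenized pairs. The toy token IDs, the hard-coded batch, and the standalone loop are illustrative assumptions, not code from the repository.

# Toy batch of pre-tokenized pairs; 'preference_score' is optional and
# defaults to 1.0, mirroring x.get('preference_score', 1.0) above.
import numpy as np

batch = [
    {"chosen": [1, 5, 9, 2], "rejected": [1, 5, 7, 7, 2], "preference_score": 0.8},
    {"chosen": [1, 6, 2], "rejected": [1, 6, 8, 2]},
]
max_seq_length = 2048

chosen_lengths = [len(x["chosen"]) for x in batch]
rejected_lengths = [len(x["rejected"]) for x in batch]
max_length = min(max(max(chosen_lengths), max(rejected_lengths)), max_seq_length)

pad_to = 8
max_length_in_batch = pad_to * ((max_length + pad_to - 1) // pad_to)   # rounds 5 up to 8

chosen_arr = np.zeros((len(batch), max_length_in_batch), np.int32)
for j, x in enumerate(batch):
    n = min(chosen_lengths[j], max_length_in_batch)
    chosen_arr[j, :n] = x["chosen"][:n]

preference_scores = np.array([x.get("preference_score", 1.0) for x in batch], np.float32)
print(chosen_arr.shape, preference_scores)   # shapes: (2, 8); scores: [0.8, 1.0]

The iterator now yields a five-element tuple ending in preference_scores instead of the old chosen_rewards/rejected_rewards pair, which is why evaluate_orpo and step() below unpack five values rather than six.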
@@ -145,20 +127,18 @@ def evaluate_orpo(model, dataset, tokenizer, batch_size, num_batches, beta: float, max_seq_length=2048):
         index_iterator,
         iterate_orpo_batches(
             dataset=dataset,
-            tokenizer=tokenizer,
             batch_size=batch_size,
             max_seq_length=max_seq_length,
         ),
     ):
-        chosen, rejected, chosen_masks, rejected_masks, chosen_rewards, rejected_rewards = batch
+        chosen, rejected, chosen_masks, rejected_masks, preference_scores = batch
         loss, reward, toks, metrics = orpo_loss(
             model=model,
             chosen=chosen,
             rejected=rejected,
             chosen_masks=chosen_masks,
             rejected_masks=rejected_masks,
-            chosen_rewards=chosen_rewards,
-            rejected_rewards=rejected_rewards,
+            preference_scores=preference_scores,
             beta=beta
         )
         all_losses += loss * toks
@@ -207,7 +187,7 @@ def train_orpo(
     state = [model.state, optimizer.state]
 
     def step(batch):
-        chosen, rejected, chosen_masks, rejected_masks, chosen_rewards, rejected_rewards = batch
+        chosen, rejected, chosen_masks, rejected_masks, preference_scores = batch
 
         (loss, reward, toks, metrics), grad = loss_value_and_grad(
             model,
@@ -215,8 +195,7 @@ def train_orpo(
             rejected,
             chosen_masks,
             rejected_masks,
-            chosen_rewards,
-            rejected_rewards
+            preference_scores=preference_scores,
         )
 
         grad = average_gradients(grad)
@@ -224,16 +203,14 @@ def train_orpo(
 
         return loss, reward, toks, metrics
 
-    def loss_wrapper(model, chosen, rejected, chosen_masks, rejected_masks,
-                     chosen_rewards, rejected_rewards):
+    def loss_wrapper(model, chosen, rejected, chosen_masks, rejected_masks, preference_scores):
         return orpo_loss(
             model=model,
             chosen=chosen,
             rejected=rejected,
             chosen_masks=chosen_masks,
             rejected_masks=rejected_masks,
-            chosen_rewards=chosen_rewards,
-            rejected_rewards=rejected_rewards,
+            preference_scores=preference_scores,
             beta=args.beta
         )
 
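Finally, a self-contained sketch of the value-and-grad pattern the reworked step()/loss_wrapper pair relies on, with a toy model and loss standing in for the real ORPO pieces. The nn.value_and_grad wiring and the toy_loss_wrapper name are assumptions about the surrounding trainer code, which this diff does not show.

# Toy stand-in for loss_wrapper: the first return value is the scalar that is
# differentiated; extra outputs (reward, num_tokens, metrics in the real code)
# are passed through untouched.
import mlx.core as mx
import mlx.nn as nn

model = nn.Linear(4, 1)

def toy_loss_wrapper(model, x, scale):
    loss = mx.mean(model(x)) * scale
    return loss, mx.array(x.shape[0])

loss_value_and_grad = nn.value_and_grad(model, toy_loss_wrapper)  # assumed wiring
(loss, n), grad = loss_value_and_grad(model, mx.ones((2, 4)), mx.array(0.5))
print(loss.item(), n.item())

In the trainer itself the gradient returned here is then averaged across workers with average_gradients(grad) before the optimizer update, as in the step() hunks above.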