From d8e783434533541abbb72e17446d3cfdda349c67 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Sat, 25 Jan 2025 21:35:37 +0100 Subject: [PATCH] Removed rejected_rewards handling, Updated batch unpacking to match iterator, Updated batch unpacking to match iterator, Added preference score scaling, Simplified reward calculation, Removed redundant rejected_rewards --- llms/mlx_lm/tuner/orpo_trainer.py | 161 +++++++++++++----------------- 1 file changed, 69 insertions(+), 92 deletions(-) diff --git a/llms/mlx_lm/tuner/orpo_trainer.py b/llms/mlx_lm/tuner/orpo_trainer.py index 66b94809..7aa07068 100644 --- a/llms/mlx_lm/tuner/orpo_trainer.py +++ b/llms/mlx_lm/tuner/orpo_trainer.py @@ -18,52 +18,50 @@ class ORPOTrainingArgs(TrainingArgs): ) -def orpo_loss(model, chosen, rejected, chosen_masks, rejected_masks, chosen_rewards, rejected_rewards, beta=0.1): - def get_logps(model, x, mask): - inputs = x[:, :-1] - targets = x[:, 1:] - logits = model(inputs) - logp = -nn.losses.cross_entropy(logits, targets, reduction='none') - seq_lengths = mask[:, :-1].sum(-1) - logp_sum = (logp * mask[:, :-1]).sum(-1) / seq_lengths - logits_mean = (logits * mask[:, :-1, None]).sum() / mask[:, :-1].sum() - return logp_sum, logits_mean +def orpo_loss(model, chosen, rejected, chosen_masks, rejected_masks, preference_scores, beta=0.1): + def get_logps(model, x, mask): + inputs = x[:, :-1] + targets = x[:, 1:] + logits = model(inputs) + logp = -nn.losses.cross_entropy(logits, targets, reduction='none') + seq_lengths = mask[:, :-1].sum(-1) + logp_sum = (logp * mask[:, :-1]).sum(-1) / seq_lengths + logits_mean = (logits * mask[:, :-1, None]).sum() / mask[:, :-1].sum() + return logp_sum, logits_mean - policy_chosen_logps, chosen_logits_mean = get_logps(model, chosen, chosen_masks) - policy_rejected_logps, rejected_logits_mean = get_logps(model, rejected, rejected_masks) - - log_odds = (policy_chosen_logps - policy_rejected_logps) - ( - mx.log1p(-mx.exp(policy_chosen_logps)) - mx.log1p(-mx.exp(policy_rejected_logps)) - ) - - ratio = nn.log_sigmoid(log_odds) - loss = -beta * ratio - - accuracies = (log_odds > 0).astype(mx.float32) - margins = mx.mean(ratio - 1) - metrics = { - 'accuracies': mx.mean(accuracies), - 'margins': margins, - 'policy_rejected_logps': mx.mean(policy_rejected_logps), - 'policy_chosen_logps': mx.mean(policy_chosen_logps), - 'rejected_logits_mean': mx.mean(rejected_logits_mean), - 'chosen_logits_mean': mx.mean(chosen_logits_mean) - } - - chosen_reward = beta * policy_chosen_logps - rejected_reward = beta * policy_rejected_logps - reward = mx.stack([mx.mean(chosen_reward), mx.mean(rejected_reward)]) - num_tokens = chosen_masks.sum() + rejected_masks.sum() - - return mx.mean(loss), reward, num_tokens, metrics + policy_chosen_logps, chosen_logits_mean = get_logps(model, chosen, chosen_masks) + policy_rejected_logps, rejected_logits_mean = get_logps(model, rejected, rejected_masks) + + # Apply preference scores + policy_chosen_logps = policy_chosen_logps * preference_scores + + log_odds = (policy_chosen_logps - policy_rejected_logps) - ( + mx.log1p(-mx.exp(policy_chosen_logps)) - mx.log1p(-mx.exp(policy_rejected_logps)) + ) + + ratio = nn.log_sigmoid(log_odds) + loss = -beta * ratio + + metrics = { + 'accuracies': mx.mean((log_odds > 0).astype(mx.float32)), + 'margins': mx.mean(ratio - 1), + 'policy_rejected_logps': mx.mean(policy_rejected_logps), + 'policy_chosen_logps': mx.mean(policy_chosen_logps), + 'rejected_logits_mean': mx.mean(rejected_logits_mean), + 'chosen_logits_mean': mx.mean(chosen_logits_mean) + } + + chosen_reward = beta * policy_chosen_logps + rejected_reward = beta * policy_rejected_logps + reward = mx.stack([mx.mean(chosen_reward), mx.mean(rejected_reward)]) + + num_tokens = chosen_masks.sum() + rejected_masks.sum() + + return mx.mean(loss), reward, num_tokens, metrics -def iterate_orpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False): - """ - Modified batch iterator for ORPO that includes preference scores. - Works with pre-tokenized input data. - """ - # Sort pairs by length of the chosen response +def iterate_orpo_batches(dataset, batch_size, max_seq_length, tokenizer, train=False): + """Batch iterator for ORPO with preference scores""" idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]['chosen'])) if len(dataset) < batch_size: @@ -71,70 +69,54 @@ def iterate_orpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=F f"Dataset must have at least batch_size={batch_size}" f" examples but only has {len(dataset)}." ) - + step = mx.distributed.init().size() if batch_size % step != 0: - raise ValueError("The batch size must be divisible by the number of workers") - - batch_idx = [ - idx[i : i + batch_size : step] - for i in range(0, len(idx) - batch_size + 1, batch_size) - ] - + raise ValueError("Batch size must be divisible by number of workers") + + batch_idx = [idx[i:i + batch_size:step] for i in range(0, len(idx) - batch_size + 1, batch_size)] + while True: indices = np.random.permutation(len(batch_idx)) if train else range(len(batch_idx)) for i in indices: batch = [dataset[j] for j in batch_idx[i]] - # Get lengths assuming data is already tokenized chosen_lengths = [len(x['chosen']) for x in batch] rejected_lengths = [len(x['rejected']) for x in batch] - max_length = max(max(chosen_lengths), max(rejected_lengths)) - - if max_length > max_seq_length: - print( - f"[WARNING] Sequences longer than {max_seq_length} tokens " - f"will be truncated." - ) - + max_length = min(max(max(chosen_lengths), max(rejected_lengths)), max_seq_length) pad_to = 8 max_length_in_batch = pad_to * ((max_length + pad_to - 1) // pad_to) - max_length_in_batch = min(max_length_in_batch, max_seq_length) - - chosen_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32) - rejected_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32) - chosen_masks = np.zeros((batch_size // step, max_length_in_batch), np.float32) - rejected_masks = np.zeros((batch_size // step, max_length_in_batch), np.float32) - # Get preference scores and convert to rewards - preference_scores = [x.get('preference_score', 1.0) for x in batch] - chosen_rewards = np.array(preference_scores, np.float32) - rejected_rewards = np.array([1.0 - score for score in preference_scores], np.float32) - - for j in range(batch_size // step): - # Use pre-tokenized sequences directly - chosen_length = min(chosen_lengths[j], max_seq_length) + batch_size_per_device = batch_size // step + chosen_arr = np.zeros((batch_size_per_device, max_length_in_batch), np.int32) + rejected_arr = np.zeros((batch_size_per_device, max_length_in_batch), np.int32) + chosen_masks = np.zeros((batch_size_per_device, max_length_in_batch), np.float32) + rejected_masks = np.zeros((batch_size_per_device, max_length_in_batch), np.float32) + + preference_scores = np.array([x.get('preference_score', 1.0) for x in batch], np.float32) + + for j in range(batch_size_per_device): + chosen_length = min(chosen_lengths[j], max_length_in_batch) + rejected_length = min(rejected_lengths[j], max_length_in_batch) + chosen_arr[j, :chosen_length] = batch[j]['chosen'][:chosen_length] chosen_masks[j, :chosen_length] = 1.0 - - rejected_length = min(rejected_lengths[j], max_seq_length) rejected_arr[j, :rejected_length] = batch[j]['rejected'][:rejected_length] rejected_masks[j, :rejected_length] = 1.0 - + yield ( mx.array(chosen_arr), mx.array(rejected_arr), mx.array(chosen_masks), mx.array(rejected_masks), - mx.array(chosen_rewards), - mx.array(rejected_rewards) + mx.array(preference_scores) ) - + if not train: break -def evaluate_orpo(model, dataset, tokenizer, batch_size, num_batches, beta: float, max_seq_length=2048): +def evaluate_orpo(model, dataset, batch_size, num_batches, beta: float, max_seq_length=2048): all_losses = 0 all_rewards = mx.zeros((2,)) all_metrics = None @@ -145,20 +127,18 @@ def evaluate_orpo(model, dataset, tokenizer, batch_size, num_batches, beta: floa index_iterator, iterate_orpo_batches( dataset=dataset, - tokenizer=tokenizer, batch_size=batch_size, max_seq_length=max_seq_length, ), ): - chosen, rejected, chosen_masks, rejected_masks, chosen_rewards, rejected_rewards = batch + chosen, rejected, chosen_masks, rejected_masks, preference_scores = batch loss, reward, toks, metrics = orpo_loss( model=model, chosen=chosen, rejected=rejected, chosen_masks=chosen_masks, rejected_masks=rejected_masks, - chosen_rewards=chosen_rewards, - rejected_rewards=rejected_rewards, + preference_scores=preference_scores, beta=beta ) all_losses += loss * toks @@ -207,7 +187,7 @@ def train_orpo( state = [model.state, optimizer.state] def step(batch): - chosen, rejected, chosen_masks, rejected_masks, chosen_rewards, rejected_rewards = batch + chosen, rejected, chosen_masks, rejected_masks, preference_scores = batch (loss, reward, toks, metrics), grad = loss_value_and_grad( model, @@ -215,8 +195,7 @@ def train_orpo( rejected, chosen_masks, rejected_masks, - chosen_rewards, - rejected_rewards + preference_scores=preference_scores, ) grad = average_gradients(grad) @@ -224,16 +203,14 @@ def train_orpo( return loss, reward, toks, metrics - def loss_wrapper(model, chosen, rejected, chosen_masks, rejected_masks, - chosen_rewards, rejected_rewards): + def loss_wrapper(model, chosen, rejected, chosen_masks, rejected_masks, preference_scores): return orpo_loss( model=model, chosen=chosen, rejected=rejected, chosen_masks=chosen_masks, rejected_masks=rejected_masks, - chosen_rewards=chosen_rewards, - rejected_rewards=rejected_rewards, + preference_scores=preference_scores, beta=args.beta )