diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index 9f5427a9..64be0a91 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -66,7 +66,6 @@ CONFIG_DEFAULTS = {
     "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
     "beta": 0.1,
     "dpo_loss_type": "sigmoid",
-    "is_reference_free": False,
     "delta": 50.0,
     "reference_model_path": None,
     "reward_scaling": 1.0,
@@ -174,13 +173,21 @@ def build_parser():
         help="Use gradient checkpointing to reduce memory use.",
         default=None,
     )
-    parser.add_argument("--beta", type=float)
-    parser.add_argument("--dpo-loss-type", type=str, choices=["sigmoid", "hinge", "ipo", "dpo"])
-    parser.add_argument("--is-reference-free", action="store_true")
-    parser.add_argument("--delta", type=float)
-    parser.add_argument("--reference-model-path", type=str)
-    parser.add_argument("--reward-scaling", type=float, help="Scaling factor for offline rewards.")
     parser.add_argument("--seed", type=int, help="The PRNG seed.")
+
+    # ORPO args
+    parser.add_argument(
+        "--beta",
+        type=float,
+        help="Temperature parameter for ORPO training.",
+        default=0.1
+    )
+    parser.add_argument(
+        "--reward-scaling",
+        type=float,
+        help="Reward scaling factor for ORPO training, not implemented.",
+        default=1.0
+    )
     return parser
 
 
@@ -239,7 +246,8 @@ def train_model(
         adapter_file=adapter_file,
         max_seq_length=args.max_seq_length,
         grad_checkpoint=args.grad_checkpoint,
-        beta=args.beta
+        beta=args.beta,
+        reward_scaling=args.reward_scaling
     )
 
     train_orpo(
@@ -288,7 +296,7 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set
             max_seq_length=args.max_seq_length,
             beta=args.beta
         )
-        print(f"Test loss {test_loss:.8f}, Rewards: {test_rewards[0]:.3f}, {test_rewards[1]:.3f}")
+        print(f"Test loss {test_loss:.3f}, Rewards: {test_rewards[0]:.3f}, {test_rewards[1]:.3f}")
     else:
         test_loss = evaluate(
             model=model,
@@ -351,4 +359,4 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
+    main()
\ No newline at end of file
diff --git a/llms/mlx_lm/tuner/orpo_trainer.py b/llms/mlx_lm/tuner/orpo_trainer.py
index fb38c1e1..36dd48b9 100644
--- a/llms/mlx_lm/tuner/orpo_trainer.py
+++ b/llms/mlx_lm/tuner/orpo_trainer.py
@@ -16,6 +16,10 @@ class ORPOTrainingArgs(TrainingArgs):
         default=0.1,
         metadata={"help": "Temperature parameter for ORPO training."}
     )
+    reward_scaling: float = field(
+        default=1.0,
+        metadata={"help": "Reward scaling factor for ORPO training, not implemented."}
+    )
 
 
 def orpo_loss(model, chosen, rejected, chosen_masks, rejected_masks, preference_scores, beta=0.1):
@@ -131,7 +135,7 @@ def evaluate_orpo(model, dataset, batch_size, num_batches, beta: float, max_seq_
         ),
     ):
         chosen, rejected, chosen_masks, rejected_masks, preference_scores = batch
-        loss, reward, toks, metrics = orpo_loss(
+        lvalue, reward, toks, metrics = orpo_loss(
             model=model,
             chosen=chosen,
             rejected=rejected,
@@ -140,7 +144,7 @@ def evaluate_orpo(model, dataset, batch_size, num_batches, beta: float, max_seq_
             preference_scores=preference_scores,
             beta=beta
         )
-        all_losses += loss * toks
+        all_losses += lvalue * toks
         all_rewards += reward * toks
         ntokens += toks
 
@@ -169,6 +173,7 @@ def train_orpo(
     optimizer,
     train_dataset,
     val_dataset,
+    loss: callable = orpo_loss,
     args: ORPOTrainingArgs = ORPOTrainingArgs(),
     training_callback: TrainingCallback = None,
 ):
@@ -188,7 +193,7 @@ def train_orpo(
     def step(batch):
         chosen, rejected, chosen_masks, rejected_masks, preference_scores = batch
 
-        (loss, reward, toks, metrics), grad = loss_value_and_grad(
+        (lvalue, reward, toks, metrics), grad = loss_value_and_grad(
             model,
             chosen,
             rejected,
@@ -200,10 +205,10 @@ def train_orpo(
         grad = average_gradients(grad)
         optimizer.update(model, grad)
 
-        return loss, reward, toks, metrics
+        return lvalue, reward, toks, metrics
 
     def loss_wrapper(model, chosen, rejected, chosen_masks, rejected_masks, preference_scores):
-        return orpo_loss(
+        return loss(
             model=model,
             chosen=chosen,
             rejected=rejected,
@@ -254,7 +259,7 @@ def train_orpo(
         if rank == 0:
             print(
                 f"Iter {it}: "
-                f"Val loss {val_loss:.8f}, "
+                f"Val loss {val_loss:.3f}, "
                 f"Val chosen reward {val_rewards[0]:.3f}, "
                 f"Val rejected reward {val_rewards[1]:.3f}, "
                 f"Val accuracy {val_metrics['accuracies']:.3f}, "
@@ -276,13 +281,15 @@ def train_orpo(
         start = time.perf_counter()
 
         # Training step
-        loss, reward, toks, metrics = step(batch)
-        losses += loss
+        lvalue, reward, toks, metrics = step(batch)
+        losses += lvalue
         rewards += reward
         n_tokens += toks
         steps += 1
+
         for k, v in metrics.items():
             accumulated_metrics[k] += v
+
         mx.eval(state, losses, rewards, n_tokens)
 
         if it % args.steps_per_report == 0 or it == args.iters:
@@ -300,7 +307,7 @@ def train_orpo(
 
         if rank == 0:
             print(
-                f"Iter {it}: Train loss {train_loss:.8f}, "
+                f"Iter {it}: Train loss {train_loss:.3f}, "
                 f"Chosen reward {train_rewards[0]:.3f}, "
                 f"Rejected reward {train_rewards[1]:.3f}, "
                 f"Accuracy {avg_metrics['accuracies']:.3f}, "
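
Usage sketch for the new pluggable loss: the `@@ -169,6 +173,7 @@` hunk adds a `loss` parameter to `train_orpo` (defaulting to `orpo_loss`), and `loss_wrapper` now dispatches to it, so any callable with `orpo_loss`'s signature and return shape can be swapped in. The snippet below illustrates one possible caller. It is a sketch, not part of the patch: `train_orpo`'s parameters before `optimizer` are not shown in the hunk, so treating `model` and `tokenizer` as its leading keyword arguments is an assumption, and `sharper_orpo_loss`, `run_custom_orpo`, and the dataset/optimizer placeholders are hypothetical names.

from mlx_lm.tuner.orpo_trainer import ORPOTrainingArgs, orpo_loss, train_orpo


def sharper_orpo_loss(model, chosen, rejected, chosen_masks, rejected_masks,
                      preference_scores, beta=0.1):
    # Same signature and return shape as orpo_loss -- (loss, rewards, ntokens, metrics),
    # which is what step() and evaluate_orpo() unpack. Here we simply delegate to the
    # stock loss with a doubled temperature to illustrate the hook.
    return orpo_loss(
        model=model,
        chosen=chosen,
        rejected=rejected,
        chosen_masks=chosen_masks,
        rejected_masks=rejected_masks,
        preference_scores=preference_scores,
        beta=2.0 * beta,
    )


def run_custom_orpo(model, tokenizer, optimizer, train_set, valid_set):
    # Assumed call shape: only the parameters from `optimizer` onward are visible
    # in the hunk above; the beta/reward_scaling fields come from ORPOTrainingArgs
    # as extended by this diff.
    train_orpo(
        model=model,
        tokenizer=tokenizer,
        optimizer=optimizer,
        train_dataset=train_set,
        val_dataset=valid_set,
        loss=sharper_orpo_loss,  # new hook added by this diff; defaults to orpo_loss
        args=ORPOTrainingArgs(beta=0.1, reward_scaling=1.0),
    )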