From 9ede9db19b5fb2350672b9666dede13d559ef3dc Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Sun, 19 Jan 2025 02:03:31 +0100
Subject: [PATCH] nits

---
 llms/mlx_lm/tuner/orpo_trainer.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/llms/mlx_lm/tuner/orpo_trainer.py b/llms/mlx_lm/tuner/orpo_trainer.py
index 984fb3b1..00066df9 100644
--- a/llms/mlx_lm/tuner/orpo_trainer.py
+++ b/llms/mlx_lm/tuner/orpo_trainer.py
@@ -7,11 +7,15 @@ import mlx.core as mx
 import numpy as np
 from mlx.utils import tree_flatten
 from mlx.nn.utils import average_gradients
-from .dpo_trainer import DPOTrainingArgs, grad_checkpoint
+from .trainer import TrainingArgs, TrainingCallback, grad_checkpoint
 
 
 @dataclass
-class ORPOTrainingArgs(DPOTrainingArgs):
+class ORPOTrainingArgs(TrainingArgs):
+    beta: float = field(
+        default=0.1,
+        metadata={"help": "Temperature parameter for ORPO training."}
+    )
     reward_scaling: float = field(
         default=1.0,
         metadata={"help": "Scaling factor for offline rewards."}
@@ -24,8 +28,8 @@ def orpo_loss(
     rejected: mx.array,
     chosen_masks: mx.array,
     rejected_masks: mx.array,
-    chosen_rewards: mx.array,  # Pre-computed rewards
-    rejected_rewards: mx.array,  # Pre-computed rewards
+    chosen_rewards: mx.array,
+    rejected_rewards: mx.array,
     beta: float,
     reward_scaling: float = 1.0,
 ):
@@ -210,7 +214,7 @@ def train_orpo(
     train_dataset,
     val_dataset,
     args: ORPOTrainingArgs = ORPOTrainingArgs(),
-    training_callback = None,
+    training_callback: TrainingCallback = None,
 ):
     """
     Training function for ORPO.
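
For context, a minimal usage sketch of the arguments once this patch is applied: ORPOTrainingArgs now carries its own beta field instead of inheriting it from DPOTrainingArgs. The import path is inferred from the file location in the diff (llms/mlx_lm/tuner/orpo_trainer.py), and the model/tokenizer/optimizer setup plus any other TrainingArgs fields are assumptions, not shown in this diff.

    from mlx_lm.tuner.orpo_trainer import ORPOTrainingArgs, train_orpo

    # Illustrative sketch (assumed setup, not part of the patch):
    # beta weights the preference term of the ORPO loss and
    # reward_scaling scales the pre-computed offline rewards,
    # matching the field defaults introduced above.
    args = ORPOTrainingArgs(beta=0.1, reward_scaling=1.0)

    # Hypothetical call; the leading model/tokenizer/optimizer
    # parameters are assumed, only the trailing ones appear in the diff:
    # train_orpo(model, tokenizer, optimizer,
    #            train_dataset, val_dataset,
    #            args=args, training_callback=None)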