Commit 9ede9db19b (parent fa80d081f2)
Author: Goekdeniz-Guelmez
Date: 2025-01-19 02:03:31 +01:00


@@ -7,11 +7,15 @@ import mlx.core as mx
 import numpy as np
 from mlx.utils import tree_flatten
 from mlx.nn.utils import average_gradients
-from .dpo_trainer import DPOTrainingArgs, grad_checkpoint
+from .trainer import TrainingArgs, TrainingCallback, grad_checkpoint
 @dataclass
-class ORPOTrainingArgs(DPOTrainingArgs):
+class ORPOTrainingArgs(TrainingArgs):
     beta: float = field(
         default=0.1,
         metadata={"help": "Temperature parameter for ORPO training."}
     )
     reward_scaling: float = field(
         default=1.0,
         metadata={"help": "Scaling factor for offline rewards."}
@@ -24,8 +28,8 @@ def orpo_loss(
     rejected: mx.array,
     chosen_masks: mx.array,
     rejected_masks: mx.array,
-    chosen_rewards: mx.array,  # Pre-computed rewards
-    rejected_rewards: mx.array,  # Pre-computed rewards
+    chosen_rewards: mx.array,
+    rejected_rewards: mx.array,
     beta: float,
     reward_scaling: float = 1.0,
 ):
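
Only the redundant "# Pre-computed rewards" comments are dropped here; the loss still consumes offline rewards alongside per-token masks. A sketch of plausible inputs (shapes and the 0/1 mask convention are assumptions, since the leading parameters sit above this hunk):

    import mlx.core as mx

    batch, seq_len = 2, 16
    # Per-token validity masks over the chosen/rejected completions (assumed 0/1).
    chosen_masks = mx.ones((batch, seq_len))
    rejected_masks = mx.ones((batch, seq_len))
    # Offline, pre-computed per-sequence rewards; per the field's help text,
    # orpo_loss scales these by reward_scaling.
    chosen_rewards = mx.array([1.0, 0.8])
    rejected_rewards = mx.array([0.1, 0.2])
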
@@ -210,7 +214,7 @@ def train_orpo(
     train_dataset,
     val_dataset,
     args: ORPOTrainingArgs = ORPOTrainingArgs(),
-    training_callback = None,
+    training_callback: TrainingCallback = None,
 ):
     """
     Training function for ORPO.
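
With training_callback now typed as TrainingCallback, callers can hook loss reporting by subclassing it. A hedged sketch, assuming the TrainingCallback interface from mlx_lm's trainer (on_train_loss_report / on_val_loss_report receiving info dicts); the dict keys used below are assumptions:

    from mlx_lm.tuner.trainer import TrainingCallback  # assumed import path

    class PrintLossCallback(TrainingCallback):
        def on_train_loss_report(self, train_info: dict):
            # Keys are assumptions; .get() keeps the sketch robust if they differ.
            print(f"iter {train_info.get('iteration')}: train loss {train_info.get('train_loss')}")

        def on_val_loss_report(self, val_info: dict):
            print(f"iter {val_info.get('iteration')}: val loss {val_info.get('val_loss')}")

    # Passed through to train_orpo, e.g.:
    # train_orpo(..., args=ORPOTrainingArgs(), training_callback=PrintLossCallback())
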