Goekdeniz-Guelmez 2025-01-19 02:03:31 +01:00
parent fa80d081f2
commit 9ede9db19b


@@ -7,11 +7,15 @@ import mlx.core as mx
 import numpy as np
 from mlx.utils import tree_flatten
 from mlx.nn.utils import average_gradients
-from .dpo_trainer import DPOTrainingArgs, grad_checkpoint
+from .trainer import TrainingArgs, grad_checkpoint
 
 
 @dataclass
-class ORPOTrainingArgs(DPOTrainingArgs):
+class ORPOTrainingArgs(TrainingArgs):
+    beta: float = field(
+        default=0.1,
+        metadata={"help": "Temperature parameter for DPO training."}
+    )
     reward_scaling: float = field(
         default=1.0,
         metadata={"help": "Scaling factor for offline rewards."}
@@ -24,8 +28,8 @@ def orpo_loss(
     rejected: mx.array,
     chosen_masks: mx.array,
     rejected_masks: mx.array,
-    chosen_rewards: mx.array,  # Pre-computed rewards
-    rejected_rewards: mx.array,  # Pre-computed rewards
+    chosen_rewards: mx.array,
+    rejected_rewards: mx.array,
     beta: float,
     reward_scaling: float = 1.0,
 ):
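For reference, a hedged sketch of what a loss with this signature typically computes: the canonical ORPO odds-ratio term over length-normalized sequence log-probabilities. The body below is an assumption for illustration, not this commit's implementation; in particular, how the pre-computed rewards and reward_scaling are consumed is a guess (here, a scaled reward margin returned for logging).

import mlx.core as mx

def orpo_loss_sketch(chosen, rejected, chosen_masks, rejected_masks,
                     chosen_rewards, rejected_rewards, beta, reward_scaling=1.0):
    # Length-normalized log-probabilities over unmasked tokens.
    chosen_lp = (chosen * chosen_masks).sum(-1) / mx.maximum(chosen_masks.sum(-1), 1)
    rejected_lp = (rejected * rejected_masks).sum(-1) / mx.maximum(rejected_masks.sum(-1), 1)

    # log odds(p) = log p - log(1 - p); ORPO penalizes -log sigmoid of the difference.
    log_odds = (chosen_lp - mx.log1p(-mx.exp(chosen_lp))) - (
        rejected_lp - mx.log1p(-mx.exp(rejected_lp)))
    # -log sigmoid(x) = log(1 + e^{-x}), computed stably via logaddexp.
    ratio_loss = mx.logaddexp(mx.zeros_like(log_odds), -log_odds)

    # Assumed use of the pre-computed rewards: a scaled margin for reporting.
    reward_margin = reward_scaling * (chosen_rewards - rejected_rewards)
    return (beta * ratio_loss).mean(), reward_margin.mean()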
@@ -210,7 +214,7 @@ def train_orpo(
     train_dataset,
     val_dataset,
     args: ORPOTrainingArgs = ORPOTrainingArgs(),
-    training_callback = None,
+    training_callback: TrainingCallback = None,
 ):
     """
     Training function for ORPO.
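The new annotation points at the shared trainer's TrainingCallback. As a minimal sketch of a compatible callback, assuming the on_train_loss_report / on_val_loss_report hook names used by mlx_lm's trainer (the info dict keys below are also assumptions):

class PrintingCallback:
    # Duck-typed callback; hook names assume mlx_lm's TrainingCallback interface.
    def on_train_loss_report(self, info: dict):
        print(f"step {info.get('iteration')}: train loss {info.get('train_loss')}")

    def on_val_loss_report(self, info: dict):
        print(f"step {info.get('iteration')}: val loss {info.get('val_loss')}")

# Hypothetical call; leading parameters (model, tokenizer, optimizer) are assumed:
# train_orpo(model, tokenizer, optimizer, train_dataset, val_dataset,
#            args=ORPOTrainingArgs(beta=0.1), training_callback=PrintingCallback())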