Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-26 18:51:18 +08:00

Commit 9ede9db19b ("nits")
Parent: fa80d081f2
@@ -7,11 +7,15 @@ import mlx.core as mx
 import numpy as np
 from mlx.utils import tree_flatten
 from mlx.nn.utils import average_gradients
-from .dpo_trainer import DPOTrainingArgs, grad_checkpoint
+from .trainer import TrainingArgs, grad_checkpoint
 
 
 @dataclass
-class ORPOTrainingArgs(DPOTrainingArgs):
+class ORPOTrainingArgs(TrainingArgs):
+    beta: float = field(
+        default=0.1,
+        metadata={"help": "Temperature parameter for DPO training."}
+    )
     reward_scaling: float = field(
         default=1.0,
         metadata={"help": "Scaling factor for offline rewards."}
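For clarity, the ORPOTrainingArgs dataclass after this hunk, reconstructed only from the lines shown above (fields inherited from TrainingArgs are omitted, and the trailing closing parenthesis is implied by the syntax), reads roughly as:

from dataclasses import dataclass, field

from .trainer import TrainingArgs


@dataclass
class ORPOTrainingArgs(TrainingArgs):
    beta: float = field(
        default=0.1,
        metadata={"help": "Temperature parameter for DPO training."}
    )
    reward_scaling: float = field(
        default=1.0,
        metadata={"help": "Scaling factor for offline rewards."}
    )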
@@ -24,8 +28,8 @@ def orpo_loss(
     rejected: mx.array,
     chosen_masks: mx.array,
     rejected_masks: mx.array,
-    chosen_rewards: mx.array,  # Pre-computed rewards
-    rejected_rewards: mx.array,  # Pre-computed rewards
+    chosen_rewards: mx.array,
+    rejected_rewards: mx.array,
     beta: float,
     reward_scaling: float = 1.0,
 ):
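The body of orpo_loss is not part of this diff. For orientation only, here is a minimal sketch of the beta-weighted odds-ratio term from the ORPO formulation; it is not the repository's implementation, and logp_chosen / logp_rejected are assumed to be per-sequence mean log-probabilities already reduced with the masks above:

import mlx.core as mx


def odds_ratio_term(logp_chosen: mx.array, logp_rejected: mx.array, beta: float) -> mx.array:
    # log odds(p) = log p - log(1 - p); log(1 - p) is computed stably as log1p(-exp(log p))
    log_odds = (logp_chosen - logp_rejected) - (
        mx.log1p(-mx.exp(logp_chosen)) - mx.log1p(-mx.exp(logp_rejected))
    )
    # Penalty combined with the chosen-response NLL: -beta * log(sigmoid(log odds ratio))
    return -beta * mx.log(mx.sigmoid(log_odds))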
@@ -210,7 +214,7 @@ def train_orpo(
     train_dataset,
     val_dataset,
     args: ORPOTrainingArgs = ORPOTrainingArgs(),
-    training_callback = None,
+    training_callback: TrainingCallback = None,
 ):
     """
     Training function for ORPO.
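The only change in this hunk is the TrainingCallback annotation on training_callback. As a usage sketch, a callback could be subclassed roughly as follows; the import path and the on_*_loss_report hook names follow mlx-lm's TrainingCallback convention and are assumptions, not something this diff shows:

from mlx_lm.tuner.trainer import TrainingCallback  # assumed import path


class PrintingCallback(TrainingCallback):
    # Hypothetical hooks: print whatever metrics the trainer reports.
    def on_train_loss_report(self, train_info: dict):
        print("train:", train_info)

    def on_val_loss_report(self, val_info: dict):
        print("val:", val_info)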