mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-06-26 18:51:18 +08:00
nits
commit 9ede9db19b (parent fa80d081f2)
@@ -7,11 +7,15 @@ import mlx.core as mx
 import numpy as np
 from mlx.utils import tree_flatten
 from mlx.nn.utils import average_gradients
-from .dpo_trainer import DPOTrainingArgs, grad_checkpoint
+from .trainer import TrainingArgs, grad_checkpoint


 @dataclass
-class ORPOTrainingArgs(DPOTrainingArgs):
+class ORPOTrainingArgs(TrainingArgs):
+    beta: float = field(
+        default=0.1,
+        metadata={"help": "Temperature parameter for DPO training."}
+    )
     reward_scaling: float = field(
         default=1.0,
         metadata={"help": "Scaling factor for offline rewards."}
     )
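With this change ORPOTrainingArgs inherits from the generic TrainingArgs instead of DPOTrainingArgs, so the beta field is now declared directly on the class. A minimal sketch of the resulting dataclass in isolation (the bare-bones TrainingArgs stand-in below is an assumption for illustration, not the real mlx_lm class):

    from dataclasses import dataclass, field

    @dataclass
    class TrainingArgs:
        # stand-in for mlx_lm's TrainingArgs; the real class has many more fields
        batch_size: int = 4

    @dataclass
    class ORPOTrainingArgs(TrainingArgs):
        beta: float = field(
            default=0.1,
            metadata={"help": "Temperature parameter for DPO training."}
        )
        reward_scaling: float = field(
            default=1.0,
            metadata={"help": "Scaling factor for offline rewards."}
        )

    # inherited and new fields are all available on one config object
    args = ORPOTrainingArgs(batch_size=8, beta=0.2)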
@@ -24,8 +28,8 @@ def orpo_loss(
     rejected: mx.array,
     chosen_masks: mx.array,
     rejected_masks: mx.array,
-    chosen_rewards: mx.array,  # Pre-computed rewards
-    rejected_rewards: mx.array,  # Pre-computed rewards
+    chosen_rewards: mx.array,
+    rejected_rewards: mx.array,
     beta: float,
     reward_scaling: float = 1.0,
 ):
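The diff shows only the signature of orpo_loss, which takes per-sequence masks and pre-computed rewards. For context, the core quantity in an ORPO-style loss is the log odds ratio between the chosen and rejected completions; a minimal sketch of that term in MLX, following the formulation in the ORPO paper rather than this file's actual body (the mask-averaged log-probabilities are an assumption about how the inputs are pooled):

    import mlx.core as mx

    def masked_mean_logps(token_logps: mx.array, masks: mx.array) -> mx.array:
        # average per-token log-probs over valid (unmasked) positions
        return (token_logps * masks).sum(-1) / mx.maximum(masks.sum(-1), 1)

    def odds_ratio_term(chosen_logps: mx.array, rejected_logps: mx.array) -> mx.array:
        # log odds(p) = log p - log(1 - p); compute log(1 - p) as
        # log1p(-exp(log p)) to stay in log space for stability
        log_odds = (chosen_logps - rejected_logps) - (
            mx.log1p(-mx.exp(chosen_logps)) - mx.log1p(-mx.exp(rejected_logps))
        )
        # ORPO maximizes sigmoid(log_odds), i.e. minimizes -log sigmoid of it
        return -mx.log(mx.sigmoid(log_odds))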
@@ -210,7 +214,7 @@ def train_orpo(
     train_dataset,
     val_dataset,
     args: ORPOTrainingArgs = ORPOTrainingArgs(),
-    training_callback = None,
+    training_callback: TrainingCallback = None,
 ):
     """
     Training function for ORPO.
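The last hunk only adds a type annotation for training_callback. In mlx_lm's trainer, TrainingCallback exposes hooks for train and validation loss reports; a minimal sketch of a custom callback (the two method names follow the trainer convention, but treat them as an assumption since the callback class itself is not shown in this diff):

    # Hypothetical callback; method names assumed from mlx_lm's trainer.
    class LoggingCallback:
        def on_train_loss_report(self, train_info: dict) -> None:
            print(f"train: {train_info}")

        def on_val_loss_report(self, val_info: dict) -> None:
            print(f"val: {val_info}")

Passing training_callback=LoggingCallback() to train_orpo would then surface periodic loss dictionaries during training.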