diff --git a/llms/mlx_lm/tuner/orpo_trainer.py b/llms/mlx_lm/tuner/orpo_trainer.py index 00066df9..4eaace31 100644 --- a/llms/mlx_lm/tuner/orpo_trainer.py +++ b/llms/mlx_lm/tuner/orpo_trainer.py @@ -7,7 +7,7 @@ import mlx.core as mx import numpy as np from mlx.utils import tree_flatten from mlx.nn.utils import average_gradients -from .trainer import TrainingArgs, grad_checkpoint +from .trainer import TrainingArgs, grad_checkpoint, TrainingCallback @dataclass