Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-08-29 06:22:46 +08:00

Fixes #513

Implement the Direct Preference Optimization (DPO) method as a Reinforcement Learning from Human Feedback (RLHF) example.

* **Add DPO Functions**: Add `get_batched_logps` and `dpo_loss` functions to `llms/mlx_lm/utils.py` for the DPO implementation.
* **Update Training Logic**: Update `llms/mlx_lm/tuner/trainer.py` with DPO-specific training logic, including a new `dpo_loss` function and a check for the DPO loss type in the training loop.
* **Add Configuration Options**: Add configuration options for DPO to `llms/mlx_lm/examples/lora_config.yaml`.
* **Update Documentation**: Update `llms/mlx_lm/README.md` with instructions for using DPO.
* **Add Unit Tests**: Add `llms/tests/test_dpo.py` with unit tests for `get_batched_logps`, `dpo_loss`, and the DPO-specific training logic.

---

For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/ml-explore/mlx-examples/issues/513?shareId=XXXX-XXXX-XXXX-XXXX).
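The `get_batched_logps` and `dpo_loss` helpers referenced above live in `llms/mlx_lm/utils.py` and are not shown on this page. As a rough guide to what the tests below exercise, here is a minimal sketch of the standard DPO objective (Rafailov et al., 2023) written with MLX, assuming the batch stacks chosen sequences in the first half and rejected sequences in the second half. The names and signatures mirror the test file, but the bodies are illustrative assumptions, not the PR's implementation, whose own conventions are what the expected values in the tests reflect.

```python
# Hypothetical sketch only -- names and signatures come from the test file
# below, but these bodies are NOT the code added in this PR.
import mlx.core as mx


def get_batched_logps(model, inputs, targets):
    # Per-token log-probabilities of the targets under the model,
    # summed into one log-probability per sequence.
    logits, _ = model(inputs)
    logits = logits.astype(mx.float32)
    log_probs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
    per_token = mx.squeeze(
        mx.take_along_axis(log_probs, mx.expand_dims(targets, -1), axis=-1), axis=-1
    )
    seq_logps = per_token.sum(axis=-1)
    # Assumed batch layout: first half chosen, second half rejected.
    half = inputs.shape[0] // 2
    return seq_logps[:half], seq_logps[half:]


def dpo_loss(
    model,
    beta,
    label_smoothing,
    reference_chosen_logps,
    reference_rejected_logps,
    inputs,
    targets,
):
    # Standard DPO objective: -log sigmoid(beta * (policy margin - reference margin)).
    chosen_logps, rejected_logps = get_batched_logps(model, inputs, targets)
    chosen_rewards = beta * (chosen_logps - reference_chosen_logps)
    rejected_rewards = beta * (rejected_logps - reference_rejected_logps)
    margins = chosen_rewards - rejected_rewards

    def log_sigmoid(x):
        # log(sigmoid(x)) = -log(1 + exp(-x))
        return -mx.logaddexp(0.0, -x)

    # label_smoothing > 0 mixes in the reversed preference direction
    # ("conservative" DPO); 0 recovers the plain sigmoid loss.
    losses = (
        -(1 - label_smoothing) * log_sigmoid(margins)
        - label_smoothing * log_sigmoid(-margins)
    )
    reward_accuracies = (chosen_rewards > rejected_rewards).astype(mx.float32)
    ntoks = mx.array(targets.size)
    return (
        losses.mean(),
        chosen_rewards.mean(),
        rejected_rewards.mean(),
        reward_accuracies.mean(),
        margins.mean(),
        ntoks,
    )
```

With `label_smoothing=0` this collapses to the plain `-log sigmoid(beta * margin)` loss; the smoothing term simply blends in the reversed preference direction.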
`llms/tests/test_dpo.py` · 49 lines · 2.2 KiB · Python
import unittest
from unittest.mock import MagicMock

import numpy as np

import mlx.core as mx
import mlx.nn as nn

from mlx_lm.utils import get_batched_logps, dpo_loss
from mlx_lm.tuner.trainer import train, TrainingArgs


class TestDPO(unittest.TestCase):

    def setUp(self):
        # A mocked model and a tiny batch: row 0 is the "chosen" sequence,
        # row 1 the "rejected" one.
        self.model = MagicMock()
        self.inputs = mx.array([[1, 2, 3], [4, 5, 6]])
        self.targets = mx.array([[1, 2, 3], [4, 5, 6]])
        self.reference_chosen_logps = mx.array([0.1, 0.2])
        self.reference_rejected_logps = mx.array([0.3, 0.4])
        self.beta = 0.1
        self.label_smoothing = 0.0

    def test_get_batched_logps(self):
        # Fake logits with shape (batch=2, sequence=3, vocab=2).
        self.model.return_value = (
            mx.array(
                [
                    [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
                    [[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]],
                ]
            ),
            None,
        )
        chosen_logps, rejected_logps = get_batched_logps(
            self.model, self.inputs, self.targets
        )
        np.testing.assert_array_almost_equal(
            np.array(chosen_logps), np.array([0.1, 0.7])
        )
        np.testing.assert_array_almost_equal(
            np.array(rejected_logps), np.array([0.3, 0.9])
        )

    def test_dpo_loss(self):
        self.model.return_value = (
            mx.array(
                [
                    [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
                    [[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]],
                ]
            ),
            None,
        )
        loss, chosen_rewards, rejected_rewards, reward_accuracies, reward_margins, ntoks = dpo_loss(
            self.model,
            self.beta,
            self.label_smoothing,
            self.reference_chosen_logps,
            self.reference_rejected_logps,
            self.inputs,
            self.targets,
        )
        # Check the loss and reward statistics returned by dpo_loss.
        self.assertAlmostEqual(loss.item(), -0.6931472)
        self.assertAlmostEqual(chosen_rewards.item(), 0.0)
        self.assertAlmostEqual(rejected_rewards.item(), 0.0)
        self.assertAlmostEqual(reward_accuracies.item(), 0.0)
        self.assertAlmostEqual(reward_margins.item(), 0.0)
        self.assertEqual(ntoks.item(), 6)

    def test_train_with_dpo_loss(self):
        # Run the training loop with mocked components and check that the
        # model is invoked when the DPO loss type is selected.
        train_dataset = MagicMock()
        val_dataset = MagicMock()
        tokenizer = MagicMock()
        optimizer = MagicMock()
        args = TrainingArgs(
            loss_type="dpo", beta=self.beta, label_smoothing=self.label_smoothing
        )
        train(self.model, tokenizer, optimizer, train_dataset, val_dataset, args=args)
        self.model.assert_called()


if __name__ == "__main__":
    unittest.main()