mlx-examples/llms/tests
Anupam Mediratta 607c300e18 Add Direct Preference Optimization (DPO) method
Fixes #513

Implement the Direct Preference Optimization (DPO) method as a Reinforcement Learning from Human Feedback (RLHF) example.
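For reference, the objective from the original DPO paper (Rafailov et al., 2023) is reproduced below. It is presumably what a `dpo_loss` function computes, but how the commit maps its arguments onto these symbols is an assumption. Here $\pi_\theta$ is the policy being trained, $\pi_{\mathrm{ref}}$ is a frozen reference model, $y_w$ and $y_l$ are the preferred and rejected responses to prompt $x$, $\sigma$ is the logistic sigmoid, and $\beta$ controls how far the policy may drift from the reference.

```latex
\mathcal{L}_{\mathrm{DPO}}(\pi_\theta; \pi_{\mathrm{ref}}) =
-\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}
\left[ \log \sigma\!\left(
\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)}
- \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}
\right) \right]
```

Each $\log \pi(y \mid x)$ term is a sum of per-token log-probabilities over the response tokens. This is the kind of quantity a helper named `get_batched_logps` would plausibly return for a batch of sequences.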

* **Add DPO Functions**: Add `get_batched_logps` and `dpo_loss` functions to `llms/mlx_lm/utils.py` to implement DPO.
* **Update Training Logic**: Update `llms/mlx_lm/tuner/trainer.py` to include DPO-specific training logic, including a new `dpo_loss` function and a condition in the training loop that checks whether the DPO loss is in use.
* **Add Configuration Options**: Add configuration options for DPO in `llms/mlx_lm/examples/lora_config.yaml`.
* **Update Documentation**: Update `llms/mlx_lm/README.md` to include instructions for using DPO.
* **Add Unit Tests**: Add `llms/tests/test_dpo.py` with unit tests for `get_batched_logps`, `dpo_loss`, and DPO-specific training logic.
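The arithmetic behind the two functions named in the first bullet can be illustrated with the plain-Python sketch below. It does not use MLX and does not reproduce the signatures in `llms/mlx_lm/utils.py`. `sequence_logp`, `log_sigmoid`, and `dpo_loss_sketch` are hypothetical names, and `beta=0.1` is an assumed default rather than a value taken from `lora_config.yaml`. The first helper sums masked per-token log-probabilities into one score per sequence. The loss then compares policy and reference scores for each (chosen, rejected) pair.

```python
import math
from typing import List, Sequence


def sequence_logp(token_logps: Sequence[float], mask: Sequence[int]) -> float:
    """Sum per-token log-probabilities where mask == 1.

    Hypothetical stand-in for the per-sequence reduction a function
    like `get_batched_logps` performs (prompt/padding tokens masked out).
    """
    return sum(lp for lp, m in zip(token_logps, mask) if m)


def log_sigmoid(z: float) -> float:
    """Numerically stable log(sigmoid(z)); exp() never sees a positive argument."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


def dpo_loss_sketch(
    policy_chosen: List[float],
    policy_rejected: List[float],
    ref_chosen: List[float],
    ref_rejected: List[float],
    beta: float = 0.1,  # assumed default, not from the commit
) -> float:
    """Mean DPO loss over a batch of (chosen, rejected) pairs.

    Each list holds one summed sequence log-probability per pair.
    """
    losses = []
    for pc, pr, rc, rr in zip(policy_chosen, policy_rejected, ref_chosen, ref_rejected):
        # Implicit reward margin: beta * [(log-ratio on chosen) - (log-ratio on rejected)]
        margin = beta * ((pc - rc) - (pr - rr))
        losses.append(-log_sigmoid(margin))
    return sum(losses) / len(losses)


# When the policy equals the reference, every margin is 0 and the loss is log(2).
print(round(dpo_loss_sketch([-5.0], [-7.0], [-5.0], [-7.0]), 4))  # 0.6931
```

Working in log-space with a stable `log_sigmoid` avoids overflow for large margins. An MLX version would apply the same arithmetic to arrays rather than Python lists.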

---

For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/ml-explore/mlx-examples/issues/513?shareId=XXXX-XXXX-XXXX-XXXX).
2025-02-12 15:21:21 +05:30
| File | Last commit | Date |
|---|---|---|
| test_datsets.py | Completion only fine-tuning of instruction models with collections of HF datasets (#1103) | 2025-02-09 20:12:34 -08:00 |
| test_dpo.py | Add Direct Preference Optimization (DPO) method | 2025-02-12 15:21:21 +05:30 |
| test_finetune.py | reduction moved to CPU in case of distributed training (#1200) | 2025-01-14 17:20:42 -08:00 |
| test_generate.py | Add "from_draft" to GenerationResponse (#1272) | 2025-02-11 15:41:02 -08:00 |
| test_gguf.py | fix(mlx-lm): type hints in gguf.py (#621) | 2024-03-26 07:56:01 -07:00 |
| test_models.py | add internlm3 (#1206) | 2025-01-15 14:55:41 -08:00 |
| test_prompt_cache.py | Allow prompt callback to generate_step (#1133) | 2024-12-03 16:17:14 -08:00 |
| test_sample_utils.py | batched min p and fix spec gen sampling (#1222) | 2025-01-27 15:40:31 -08:00 |
| test_server.py | chore(mlx-lm): support text type content in messages (#1225) | 2025-01-27 17:13:50 -08:00 |
| test_tokenizers.py | Fix decoding manually added tokens (#1164) | 2024-12-17 09:54:29 -08:00 |
| test_tuner_utils.py | LoRA: Extract small function (#614) | 2024-06-02 06:38:42 -07:00 |
| test_utils_load_model.py | deepseek v3 model with pipeline parallelism (#1191) | 2025-01-09 15:55:53 -08:00 |
| test_utils.py | Fix whipser conversion for safetensors models (#935) | 2024-08-14 10:22:04 -07:00 |