mlx-examples/llms/mlx_lm
Anupam Mediratta 607c300e18 Add Direct Preference Optimization (DPO) method
Fixes #513

Implement the Direct Preference Optimization (DPO) method as a Reinforcement Learning from Human Feedback (RLHF) example.

* **Add DPO Functions**: Add `get_batched_logps` and `dpo_loss` functions to `llms/mlx_lm/utils.py` for DPO implementation.
* **Update Training Logic**: Update `llms/mlx_lm/tuner/trainer.py` to include DPO-specific training logic, including a new `dpo_loss` function and condition to check for DPO loss in the training loop.
* **Add Configuration Options**: Add configuration options for DPO in `llms/mlx_lm/examples/lora_config.yaml`.
* **Update Documentation**: Update `llms/mlx_lm/README.md` to include instructions for using DPO.
* **Add Unit Tests**: Add `llms/tests/test_dpo.py` with unit tests for `get_batched_logps`, `dpo_loss`, and DPO-specific training logic.

---

For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/ml-explore/mlx-examples/issues/513?shareId=XXXX-XXXX-XXXX-XXXX).
2025-02-12 15:21:21 +05:30
examples Add Direct Preference Optimization (DPO) method 2025-02-12 15:21:21 +05:30
models fix sharding for more even number of layers (#1276) 2025-02-11 16:26:59 -08:00
tuner Add Direct Preference Optimization (DPO) method 2025-02-12 15:21:21 +05:30
__init__.py Fix detokenizer space match for quote (#1072) 2024-10-27 15:06:07 -07:00
_version.py deepseek v3 model with pipeline parallelism (#1191) 2025-01-09 15:55:53 -08:00
cache_prompt.py Fix prompt cache for models without chat template (#1250) 2025-02-06 11:10:58 -08:00
chat.py fix encoding with special tokens + chat template (#1189) 2025-01-03 10:50:59 -08:00
convert.py override dtype with quant (#1062) 2024-10-22 09:56:45 -07:00
evaluate.py fix generation evaluations (#1277) 2025-02-11 16:10:30 -08:00
fuse.py Adding full finetuning (#903) 2024-09-29 17:12:47 -07:00
generate.py Add IBM granite model (#1265) 2025-02-08 15:46:15 -08:00
gguf.py Fix export to gguf (#993) 2024-09-20 13:33:45 -07:00
LORA.md Completion only fine-tuning of instruction models with collections of HF datasets (#1103) 2025-02-09 20:12:34 -08:00
lora.py Completion only fine-tuning of instruction models with collections of HF datasets (#1103) 2025-02-09 20:12:34 -08:00
MANAGE.md Add model management functionality for local caches (#736) 2024-05-03 12:20:13 -07:00
manage.py Improvements to mlx_lm.manage (#1178) 2025-01-01 07:25:57 -08:00
MERGE.md Create executables for generate, lora, server, merge, convert (#682) 2024-04-16 16:08:49 -07:00
merge.py Create executables for generate, lora, server, merge, convert (#682) 2024-04-16 16:08:49 -07:00
py.typed Add py.typed to support PEP-561 (type-hinting) (#389) 2024-01-30 21:17:38 -08:00
README.md Add Direct Preference Optimization (DPO) method 2025-02-12 15:21:21 +05:30
requirements.txt deepseek v3 model with pipeline parallelism (#1191) 2025-01-09 15:55:53 -08:00
sample_utils.py batched min p and fix spec gen sampling (#1222) 2025-01-27 15:40:31 -08:00
SERVER.md Fix object property value in mlx_lm.server chat completions response to match OpenAI spec (#1119) 2024-11-24 16:37:37 -08:00
server.py chore(mlx-lm): support text type content in messages (#1225) 2025-01-27 17:13:50 -08:00
tokenizer_utils.py Completion only fine-tuning of instruction models with collections of HF datasets (#1103) 2025-02-09 20:12:34 -08:00
UPLOAD.md Mlx llm package (#301) 2024-01-12 10:25:56 -08:00
utils.py Add Direct Preference Optimization (DPO) method 2025-02-12 15:21:21 +05:30

Generate Text with MLX and 🤗 Hugging Face

This is an example of large language model text generation that can pull models from the Hugging Face Hub.

For more information on this example, see the README in the parent directory.

This package also supports fine-tuning with LoRA or QLoRA. For more information, see the LoRA documentation.

Reinforcement Learning from Human Feedback (RLHF) with Direct Preference Optimization (DPO)

This package now includes an example of Reinforcement Learning from Human Feedback (RLHF) using the Direct Preference Optimization (DPO) method.

Paper

Direct Preference Optimization: Your Language Model is Secretly a Reward Model

Notes

Direct Preference Optimization (DPO): A Simplified Explanation by João Lages

Implementation examples

Possible MLX implementation

Policy and reference log probabilities:

import mlx.core as mx
import mlx.nn as nn

def get_batched_logps(model, inputs, targets):
    # Forward pass; cast logits to float32 for numerically stable log-probs.
    logits, _ = model(inputs)
    logits = logits.astype(mx.float32)

    # Mask padded positions (token id 0) so they do not contribute.
    loss_mask = targets != 0
    per_token_logps = mx.take_along_axis(
        nn.log_softmax(logits), targets[..., None], axis=2
    ).squeeze(2)

    # First half of the batch holds chosen sequences, second half rejected.
    return tuple((per_token_logps * loss_mask).sum(-1).split(2))
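`get_batched_logps` assumes the chosen and rejected sequences are concatenated along the batch axis (first half chosen, second half rejected) and that token id 0 marks padding. The masked per-sequence log-probability it computes can be sketched in plain Python (a hypothetical reference helper for illustration, not part of the package):

```python
def sequence_logp(per_token_logps, targets, pad_id=0):
    """Sum token log-probs, skipping padded positions (target == pad_id).

    Mirrors the loss_mask applied in get_batched_logps.
    """
    return sum(lp for lp, t in zip(per_token_logps, targets) if t != pad_id)
```

For example, `sequence_logp([-1.0, -2.0, -3.0], [5, 7, 0])` ignores the final padded position and returns `-3.0`.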

Loss:

def dpo_loss(model, beta, label_smoothing, reference_chosen_logps, reference_rejected_logps, inputs, targets):
    chosen_logps, rejected_logps = get_batched_logps(model, inputs, targets)

    # Log-ratios of chosen over rejected for the policy and the reference model.
    pi_logratios = chosen_logps - rejected_logps
    reference_logratios = reference_chosen_logps - reference_rejected_logps

    # DPO objective: a logistic loss on the difference of log-ratios,
    # optionally smoothed to account for noisy preference labels (cDPO).
    logits = pi_logratios - reference_logratios
    losses = (
        -nn.log_sigmoid(beta * logits) * (1.0 - label_smoothing)
        - nn.log_sigmoid(-beta * logits) * label_smoothing
    )

    # Implicit rewards and training diagnostics.
    chosen_rewards = beta * (chosen_logps - reference_chosen_logps)
    rejected_rewards = beta * (rejected_logps - reference_rejected_logps)
    reward_accuracies = (chosen_rewards > rejected_rewards).astype(mx.float32)
    reward_margins = chosen_rewards - rejected_rewards

    # Count non-padding tokens for logging.
    ntoks = (inputs != 0).sum()

    return (
        losses.mean(),
        chosen_rewards.mean(),
        rejected_rewards.mean(),
        reward_accuracies.mean(),
        reward_margins.mean(),
        ntoks,
    )
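For intuition, the same formula can be evaluated on scalar sequence log-probabilities in plain Python (a hypothetical sanity check, not part of the package). When the policy and reference log-ratios agree, the loss is log 2; a larger policy margin on the chosen response drives it toward 0:

```python
import math

def log_sigmoid(x):
    # Numerically stable log(sigmoid(x)).
    return -math.log1p(math.exp(-x)) if x >= 0 else x - math.log1p(math.exp(x))

def dpo_loss_scalar(chosen_logp, rejected_logp,
                    ref_chosen_logp, ref_rejected_logp,
                    beta=0.1, label_smoothing=0.0):
    pi_logratio = chosen_logp - rejected_logp
    ref_logratio = ref_chosen_logp - ref_rejected_logp
    logits = pi_logratio - ref_logratio
    return (-log_sigmoid(beta * logits) * (1.0 - label_smoothing)
            - log_sigmoid(-beta * logits) * label_smoothing)
```

For example, `dpo_loss_scalar(-1.0, -1.0, -1.0, -1.0)` evaluates to `log 2 ≈ 0.693`, since the policy and reference log-ratios cancel.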

Beta: The temperature parameter for the DPO loss, typically set in the range of 0.1 to 0.5. As beta approaches 0, the reference model is ignored.

Label smoothing: This parameter controls the conservativeness of the DPO loss. It encodes the assumption that preference labels are noisy and may be flipped with probability label_smoothing.

Note

Setting label_smoothing > 0 gives the conservative DPO (cDPO) loss.

Usage Instructions

To use the Direct Preference Optimization (DPO) method in your training, follow these steps:

  1. Add Configuration Options: Update your configuration file (e.g., llms/mlx_lm/examples/lora_config.yaml) to include the DPO-specific options:

    loss_type: "dpo"
    beta: 0.1
    label_smoothing: 0.0
    
  2. Implement DPO Functions: Ensure that the get_batched_logps and dpo_loss functions are implemented in your llms/mlx_lm/utils.py file.

  3. Update Training Logic: Modify your training script (e.g., llms/mlx_lm/tuner/trainer.py) to include DPO-specific training logic. This involves updating the train function to check for the DPO loss type and apply the DPO loss calculation accordingly.

  4. Run Training: Execute your training script with the updated configuration and logic to train your model using the DPO method.
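The dispatch described in step 3 can be sketched as follows (hypothetical glue code; the actual function names and signatures in trainer.py may differ):

```python
def make_loss_fn(config, dpo_loss_fn, default_loss_fn):
    """Select the training loss based on the ``loss_type`` config option."""
    if config.get("loss_type") == "dpo":
        beta = config.get("beta", 0.1)
        label_smoothing = config.get("label_smoothing", 0.0)

        def loss_fn(model, ref_chosen_logps, ref_rejected_logps, inputs, targets):
            # Bind the DPO hyperparameters from the config into the loss call.
            return dpo_loss_fn(model, beta, label_smoothing,
                               ref_chosen_logps, ref_rejected_logps,
                               inputs, targets)

        return loss_fn
    return default_loss_fn
```

With `loss_type: "dpo"` in the config, the returned callable forwards the configured `beta` and `label_smoothing` to the DPO loss; otherwise the default loss is used unchanged.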

By following these steps, you can leverage the Direct Preference Optimization (DPO) method for Reinforcement Learning from Human Feedback (RLHF) in your MLX training pipeline.