From 09ed837896dc2ca2393adfc1e6b68beaac049288 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Fri, 24 Jan 2025 16:57:18 +0100 Subject: [PATCH] updates --- llms/mlx_lm/LORA.md | 83 ++++++++++++++----------------- llms/mlx_lm/tuner/datasets.py | 78 +++++++++++++++++------------ llms/mlx_lm/tuner/orpo_trainer.py | 8 +-- 3 files changed, 88 insertions(+), 81 deletions(-) diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md index 93e096c3..940a473c 100644 --- a/llms/mlx_lm/LORA.md +++ b/llms/mlx_lm/LORA.md @@ -19,7 +19,6 @@ LoRA (QLoRA).[^qlora] LoRA fine-tuning works with the following model families: - [Run](#Run) - [Fine-tune](#Fine-tune) - - [DPO Training](#DPO Training) - [ORPO Training](#ORPO Training) - [Evaluate](#Evaluate) - [Generate](#Generate) @@ -79,64 +78,56 @@ You can specify the output location with `--adapter-path`. You can resume fine-tuning with an existing adapter with `--resume-adapter-file `. -### DPO Training - -Direct Preference Optimization (DPO) training allows you to fine-tune models using human preference data. To use DPO training, set the training mode to 'dpo': - -```shell -mlx_lm.lora \ - --model \ - --train \ - --training-mode dpo \ - --data \ - --beta 0.1 -``` - -The DPO training accepts the following additional parameters: - -- `--beta`: Controls the strength of the DPO loss (default: 0.1) -- `--dpo-loss-type`: Choose between "sigmoid" (default), "hinge", "ipo", or "dpop" loss functions -- `--is-reference-free`: Enable reference-free DPO training -- `--delta`: Margin parameter for hinge loss (default: 50.0) -- `--reference-model-path`: Path to a reference model for DPO training - -For DPO training, the data should be in JSONL format with the following structure: - -```jsonl -{"prompt": "User prompt", "chosen": "Preferred response", "rejected": "Less preferred response"} -``` - -Here's the equivalent ORPO documentation: - ### ORPO Training -Odds Ratio Preference Optimization (ORPO) training allows you to fine-tune models using human preference data with pre-computed rewards. To use ORPO training, set the training mode to 'orpo': +Odds Ratio Preference Optimization (ORPO) training fine-tunes models using human preference data. Usage: ```shell mlx_lm.lora \ - --model \ - --train \ - --training-mode orpo \ - --data \ - --beta 0.1 \ - --reward-scaling 1.0 + --model \ + --train \ + --training-mode orpo \ + --data \ + --beta 0.1 ``` -The ORPO training accepts the following additional parameters: -- `--beta`: Controls the temperature parameter for the logistic function (default: 0.1) -- `--reward-scaling`: Scaling factor for the offline rewards (default: 1.0) +Parameters: -For ORPO training, the data should be in JSONL format with the following structure: +- `--beta`: Temperature for logistic function (default: 0.1) + +Data format (JSONL): ```jsonl +# Basic format with string responses {"prompt": "User prompt", "chosen": "Preferred response", "rejected": "Less preferred response"} + +# With custom preference score +{"prompt": "User prompt", "chosen": "Preferred response", "rejected": "Less preferred response", "preference_score": 8.0} + +# With system message +{"prompt": "User prompt", "chosen": "Preferred response", "rejected": "Less preferred response", "system": "System instruction"} + +# With full conversation objects +{ + "prompt": "User prompt", + "chosen": { + "messages": [ + {"role": "system", "content": "System instruction"}, + {"role": "user", "content": "User message"}, + {"role": "assistant", "content": "Assistant response"} + ] + }, + "rejected": { + "messages": [ + {"role": "system", "content": "System instruction"}, + {"role": "user", "content": "User message"}, + {"role": "assistant", "content": "Assistant response"} + ] + } +} ``` -The training process will automatically assign binary rewards (1.0 for chosen and 0.0 for rejected responses) if no explicit rewards are provided. You can also provide custom rewards in your data: - -```jsonl -{"prompt": "User prompt", "chosen": "Preferred response", "rejected": "Less preferred response", "chosen_reward": 0.8, "rejected_reward": 0.3} -``` +The trainer assigns binary rewards (1.0 chosen, 0.0 rejected) if no explicit rewards provided via `preference_score`. ### Evaluate diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py index 0914c6b7..2330d4ac 100644 --- a/llms/mlx_lm/tuner/datasets.py +++ b/llms/mlx_lm/tuner/datasets.py @@ -1,50 +1,66 @@ import json from pathlib import Path -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union from transformers import PreTrainedTokenizer class ORPODataset: - def __init__( - self, - data: List[Dict[str, str]], - tokenizer: PreTrainedTokenizer, - prompt_key: str = "prompt", - chosen_key: str = "chosen", - rejected_key: str = "rejected", - preference_score_key: str = "preference_score" - ): + def __init__( + self, + data: List[Dict[str, Union[str, Dict]]], + tokenizer: PreTrainedTokenizer, + prompt_key: str = "prompt", + chosen_key: str = "chosen", + rejected_key: str = "rejected", + preference_score_key: str = "preference_score", + system_key: str = None + ): self._chosen_data = [] self._rejected_data = [] self._scores = [] - + for d in data: - chosen_text = tokenizer.apply_chat_template([ - {"role": "user", "content": d[prompt_key]}, - {"role": "assistant", "content": d[chosen_key]}, - ]) - rejected_text = tokenizer.apply_chat_template([ - {"role": "user", "content": d[prompt_key]}, - {"role": "assistant", "content": d[rejected_key]}, - ]) - + if system_key and system_key in d: + base_messages = [{"role": "system", "content": d[system_key]}] + chosen_messages = base_messages + [{"role": "user", "content": d[prompt_key]}] + if isinstance(d[chosen_key], str): + chosen_messages.append({"role": "assistant", "content": d[chosen_key]}) + else: + chosen_messages.extend(d[chosen_key]["messages"]) + rejected_messages = base_messages + [{"role": "user", "content": d[prompt_key]}] + if isinstance(d[rejected_key], str): + rejected_messages.append({"role": "assistant", "content": d[rejected_key]}) + else: + rejected_messages.extend(d[rejected_key]["messages"]) + chosen_text = tokenizer.apply_chat_template(chosen_messages) + rejected_text = tokenizer.apply_chat_template(rejected_messages) + else: + chosen_text = tokenizer.apply_chat_template([ + {"role": "user", "content": d[prompt_key]}, + {"role": "assistant", "content": d[chosen_key] if isinstance(d[chosen_key], str) else d[chosen_key]["messages"][-1]["content"]}, + ]) + rejected_text = tokenizer.apply_chat_template([ + {"role": "user", "content": d[prompt_key]}, + {"role": "assistant", "content": d[rejected_key] if isinstance(d[rejected_key], str) else d[rejected_key]["messages"][-1]["content"]}, + ]) + self._chosen_data.append(chosen_text) self._rejected_data.append(rejected_text) - + if preference_score_key in d: self._scores.append(float(d[preference_score_key])) else: self._scores.append(1.0) - - def __getitem__(self, idx: int): - return { - "chosen": self._chosen_data[idx], - "rejected": self._rejected_data[idx], - "preference_score": self._scores[idx] - } - - def __len__(self): - return len(self._chosen_data) + + def __len__(self): + return len(self._chosen_data) + + def __getitem__(self, idx: int): + return { + "chosen": self._chosen_data[idx], + "rejected": self._rejected_data[idx], + "preference_score": self._scores[idx] + } class Dataset: diff --git a/llms/mlx_lm/tuner/orpo_trainer.py b/llms/mlx_lm/tuner/orpo_trainer.py index 6963ca40..66b94809 100644 --- a/llms/mlx_lm/tuner/orpo_trainer.py +++ b/llms/mlx_lm/tuner/orpo_trainer.py @@ -40,7 +40,7 @@ def orpo_loss(model, chosen, rejected, chosen_masks, rejected_masks, chosen_rewa loss = -beta * ratio accuracies = (log_odds > 0).astype(mx.float32) - margins = mx.mean(ratio) + margins = mx.mean(ratio - 1) metrics = { 'accuracies': mx.mean(accuracies), 'margins': margins, @@ -107,9 +107,9 @@ def iterate_orpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=F rejected_masks = np.zeros((batch_size // step, max_length_in_batch), np.float32) # Get preference scores and convert to rewards - preference_scores = np.array([x.get('preference_score', 1.0) for x in batch], np.float32) - chosen_rewards = preference_scores - rejected_rewards = 1.0 - preference_scores + preference_scores = [x.get('preference_score', 1.0) for x in batch] + chosen_rewards = np.array(preference_scores, np.float32) + rejected_rewards = np.array([1.0 - score for score in preference_scores], np.float32) for j in range(batch_size // step): # Use pre-tokenized sequences directly