Goekdeniz-Guelmez 2025-01-19 13:45:33 +01:00
parent ea0d11cd2f
commit 363bde634e
5 changed files with 55 additions and 36 deletions

View File

@@ -110,7 +110,7 @@ Here's the equivalent ORPO documentation:
 ### ORPO Training
-Offline Reward Policy Optimization (ORPO) training allows you to fine-tune models using human preference data with pre-computed rewards. To use ORPO training, set the training mode to 'orpo':
+Odds Ratio Preference Optimization (ORPO) training allows you to fine-tune models using human preference data with pre-computed rewards. To use ORPO training, set the training mode to 'orpo':
 ```shell
 mlx_lm.lora \

View File

@@ -70,7 +70,6 @@ CONFIG_DEFAULTS = {
     "is_reference_free": False,
     "delta": 50.0,
     "reference_model_path": None,
-    "train_bias_only": False,
     "reward_scaling": 1.0,
 }
@@ -181,7 +180,6 @@ def build_parser():
     parser.add_argument("--is-reference-free", action="store_true")
     parser.add_argument("--delta", type=float)
     parser.add_argument("--reference-model-path", type=str)
-    parser.add_argument("--train-bias-only", action="store_true")
     parser.add_argument("--reward-scaling", type=float, help="Scaling factor for offline rewards.")
     parser.add_argument("--seed", type=int, help="The PRNG seed.")
     return parser
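
Note: scripts in this family usually backfill any CLI flag the user left unset from `CONFIG_DEFAULTS`. The actual merge code is not part of this diff; the following is only a sketch of that common pattern, assuming `build_parser()` and `CONFIG_DEFAULTS` as shown above.

```python
# Sketch only: how defaults such as reward_scaling=1.0 typically take effect
# when the corresponding flag is not passed on the command line.
args = vars(build_parser().parse_args())
for key, default in CONFIG_DEFAULTS.items():
    if args.get(key) is None:   # flag not supplied by the user
        args[key] = default     # fall back to the documented default
```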
@@ -247,7 +245,6 @@ def train_model(
         is_reference_free=args.is_reference_free,
         delta=args.delta,
         reference_model_path=args.reference_model_path,
-        train_bias_only=args.train_bias_only,
     )

     if args.reference_model_path:
@@ -278,8 +275,6 @@ def train_model(
         grad_checkpoint=args.grad_checkpoint,
         beta=args.beta,
         reward_scaling=args.reward_scaling,
-        train_bias_only=args.train_bias_only,
-        seed=args.seed,
     )

     train_orpo(

View File

@@ -8,8 +8,8 @@ from transformers import PreTrainedTokenizer
 class DPODataset:
     """
     A dataset for DPO (Direct Preference Optimization) training that handles
-    prompt-chosen-rejected triplets in the format:
-    {"prompt": ..., "chosen": ..., "rejected": ...}
+    prompt-chosen-rejected triplets with optional scores in the format:
+    {"prompt": ..., "chosen": ..., "rejected": ..., "score_chosen": ..., "score_rejected": ...}
     """

     def __init__(
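
For concreteness, a hypothetical record in this format might look like the following; the key names come from the docstring and keyword arguments in this file, while the values are made up for illustration. The `score_*` keys are optional and are handled in the next hunk.

```python
# Hypothetical preference record; without the score_* fields the pair is
# treated as a hard binary preference.
example_record = {
    "prompt": "Explain what ORPO training does.",
    "chosen": "ORPO fine-tunes a model on preference pairs by ...",
    "rejected": "I don't know.",
    "score_chosen": 8.0,
    "score_rejected": 3.0,
}
```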
@@ -19,31 +19,51 @@ class DPODataset:
         prompt_key: str = "prompt",
         chosen_key: str = "chosen",
         rejected_key: str = "rejected",
+        score_chosen_key: str = "score_chosen",
+        score_rejected_key: str = "score_rejected",
     ):
-        self._chosen_data = [
-            tokenizer.apply_chat_template(
-                [
-                    {"role": "user", "content": d[prompt_key]},
-                    {"role": "assistant", "content": d[chosen_key]},
-                ],
-            )
-            for d in data
-        ]
-
-        self._rejected_data = [
-            tokenizer.apply_chat_template(
-                [
-                    {"role": "user", "content": d[prompt_key]},
-                    {"role": "assistant", "content": d[rejected_key]},
-                ],
-            )
-            for d in data
-        ]
+        self._chosen_data = []
+        self._rejected_data = []
+        self._scores = []
+
+        for d in data:
+            # Process the text data
+            chosen_text = tokenizer.apply_chat_template(
+                [
+                    {"role": "user", "content": d[prompt_key]},
+                    {"role": "assistant", "content": d[chosen_key]},
+                ],
+            )
+            rejected_text = tokenizer.apply_chat_template(
+                [
+                    {"role": "user", "content": d[prompt_key]},
+                    {"role": "assistant", "content": d[rejected_key]},
+                ],
+            )
+
+            self._chosen_data.append(chosen_text)
+            self._rejected_data.append(rejected_text)
+
+            # Handle scores if they exist
+            if score_chosen_key in d and score_rejected_key in d:
+                chosen_score = float(d[score_chosen_key])
+                rejected_score = float(d[score_rejected_key])
+                # Normalize scores to [0, 1] range
+                score_diff = chosen_score - rejected_score
+                max_diff = max(abs(score_diff), 1.0)  # Avoid division by zero
+                normalized_score = (score_diff / max_diff + 1) / 2
+                self._scores.append(normalized_score)
+            else:
+                # Default to binary preference (1.0) if no scores provided
+                self._scores.append(1.0)

     def __getitem__(self, idx: int):
         return {
             "chosen": self._chosen_data[idx],
-            "rejected": self._rejected_data[idx]
+            "rejected": self._rejected_data[idx],
+            "preference_score": self._scores[idx]
         }

     def __len__(self):
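
The score handling above maps a raw score gap onto a [0, 1] preference. A quick standalone check of that arithmetic, using hypothetical scores rather than data from this commit:

```python
def normalize_preference(score_chosen: float, score_rejected: float) -> float:
    """Same arithmetic as DPODataset above: map a score gap to [0, 1]."""
    score_diff = score_chosen - score_rejected
    max_diff = max(abs(score_diff), 1.0)  # avoid division by zero
    return (score_diff / max_diff + 1) / 2

print(normalize_preference(8.0, 3.0))   # 1.0   (any gap >= 1 saturates)
print(normalize_preference(0.7, 0.4))   # ~0.65 (small gaps stay soft)
print(normalize_preference(4.0, 4.0))   # 0.5   (tie)
print(normalize_preference(3.0, 8.0))   # 0.0   (rejected strongly preferred)
```

Because `max_diff` is clamped at 1.0 from below, only score gaps with magnitude below 1 produce soft preferences; larger gaps saturate to exactly 0.0 or 1.0.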

View File

@@ -45,12 +45,6 @@ class DPOTrainingArgs(TrainingArgs):
             "help": "Path to reference model weights. If None, uses the same model."
         }
     )
-    train_bias_only: bool = field(
-        default=False,
-        metadata={
-            "help": "Whether to train only bias terms in the model."
-        }
-    )
     seed: int = field(
         default=42,
         metadata={
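
Downstream, `train_model` fills an args object of this kind from the parsed flags. A rough sketch of that construction follows, assuming the constructor called in `train_model` above is this class and that `beta`, `delta`, `is_reference_free`, and `reward_scaling` are also declared as fields (only `reference_model_path`, `seed`, and the removed `train_bias_only` are visible in this hunk).

```python
# Sketch only: field names mirror the CLI flags and defaults in lora.py above;
# beta / delta / is_reference_free / reward_scaling are assumed fields.
training_args = DPOTrainingArgs(
    beta=0.1,
    delta=50.0,
    is_reference_free=False,
    reference_model_path=None,   # None -> use the policy model as reference
    reward_scaling=1.0,
    seed=42,
)
```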

View File

@@ -30,19 +30,19 @@ def orpo_loss(
     rejected_masks: mx.array,
     chosen_rewards: mx.array,
     rejected_rewards: mx.array,
-    beta: float,
+    beta: float = 0.1,
     reward_scaling: float = 1.0,
 ):
     """
-    Calculate ORPO loss using pre-computed rewards.
+    Calculate ORPO loss using pre-computed rewards that incorporate preference scores.

     Args:
         model: Policy model
         chosen: Chosen sequence tokens
         rejected: Rejected sequence tokens
         chosen_masks: Attention masks for chosen sequences
         rejected_masks: Attention masks for rejected sequences
-        chosen_rewards: Pre-computed rewards for chosen sequences
-        rejected_rewards: Pre-computed rewards for rejected sequences
+        chosen_rewards: Rewards for chosen sequences (derived from preference scores)
+        rejected_rewards: Rewards for rejected sequences (derived from preference scores)
         beta: Temperature parameter
         reward_scaling: Scaling factor for rewards

     Returns:
@@ -65,7 +65,7 @@ def orpo_loss(
     chosen_rewards = chosen_rewards * reward_scaling
     rejected_rewards = rejected_rewards * reward_scaling

-    # ORPO uses the reward difference directly
+    # Calculate reward difference
     reward_diff = chosen_rewards - rejected_rewards

     # Calculate ORPO loss using logistic function
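
Only the reward scaling and reward-difference steps are visible in this hunk. As a rough, numerically stable sketch of a logistic loss on that quantity, assuming the logistic term is applied to `beta` times the scaled reward difference (how the full `orpo_loss` combines this with the policy log-probabilities over the chosen and rejected tokens is not shown here):

```python
import mlx.core as mx

def logistic_preference_loss(chosen_rewards: mx.array,
                             rejected_rewards: mx.array,
                             beta: float = 0.1,
                             reward_scaling: float = 1.0) -> mx.array:
    """Sketch: -log(sigmoid(beta * (r_chosen - r_rejected))) per pair."""
    chosen_rewards = chosen_rewards * reward_scaling
    rejected_rewards = rejected_rewards * reward_scaling
    reward_diff = chosen_rewards - rejected_rewards
    # -log(sigmoid(x)) == log(1 + exp(-x)) == logaddexp(0, -x), the stable form
    return mx.logaddexp(0.0, -beta * reward_diff)

# Example: a confident pair (1.0 vs 0.0) gives a lower loss than a tie (0.5 vs 0.5)
print(logistic_preference_loss(mx.array([1.0, 0.5]), mx.array([0.0, 0.5])))
```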
@@ -140,7 +140,7 @@ def evaluate_orpo(
 def iterate_orpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
     """
-    Modified batch iterator for ORPO that includes pre-computed rewards.
+    Modified batch iterator for ORPO that includes preference scores.
     Works with pre-tokenized input data.
     """
     # Sort pairs by length of the chosen response
@@ -186,9 +186,14 @@ def iterate_orpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
             chosen_masks = np.zeros((batch_size // step, max_length_in_batch), np.float32)
             rejected_masks = np.zeros((batch_size // step, max_length_in_batch), np.float32)

-            # Always use binary rewards
-            chosen_rewards = np.ones((batch_size // step,), np.float32)
-            rejected_rewards = np.zeros((batch_size // step,), np.float32)
+            # Get preference scores and convert to rewards
+            preference_scores = np.array([x.get('preference_score', 1.0) for x in batch], np.float32)
+            # Convert preference scores to chosen/rejected rewards
+            # When preference_score is 1.0, chosen_reward=1.0, rejected_reward=0.0
+            # When preference_score is 0.0, chosen_reward=0.0, rejected_reward=1.0
+            # When preference_score is 0.5, both rewards are 0.5
+            chosen_rewards = preference_scores
+            rejected_rewards = 1.0 - preference_scores

             for j in range(batch_size // step):
                 # Use pre-tokenized sequences directly
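
The mapping above from a single preference score s to a reward pair is simply s -> (s, 1 - s). A tiny standalone illustration with hypothetical scores:

```python
import numpy as np

# Hypothetical per-example preference scores, as produced by DPODataset:
# 1.0 = fully prefer chosen, 0.5 = tie, 0.0 = fully prefer rejected.
preference_scores = np.array([1.0, 0.65, 0.5, 0.0], np.float32)

chosen_rewards = preference_scores          # -> [1.0, 0.65, 0.5, 0.0]
rejected_rewards = 1.0 - preference_scores  # -> [0.0, 0.35, 0.5, 1.0]
```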
@@ -200,9 +205,14 @@ def iterate_orpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
                 rejected_arr[j, :rejected_length] = batch[j]['rejected'][:rejected_length]
                 rejected_masks[j, :rejected_length] = 1.0

-            yield (mx.array(chosen_arr), mx.array(rejected_arr),
-                   mx.array(chosen_masks), mx.array(rejected_masks),
-                   mx.array(chosen_rewards), mx.array(rejected_rewards))
+            yield (
+                mx.array(chosen_arr),
+                mx.array(rejected_arr),
+                mx.array(chosen_masks),
+                mx.array(rejected_masks),
+                mx.array(chosen_rewards),
+                mx.array(rejected_rewards)
+            )

         if not train:
             break
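
Putting the pieces together, a rough sketch of how the iterator could feed `orpo_loss`. The tuple order follows the `yield` above and the keyword names follow the docstring; `model`, `tokenizer`, and `dataset` are assumed to exist, and since this diff does not show what `orpo_loss` returns, the loop below only illustrates the call.

```python
# Sketch only; return values of orpo_loss are not shown in this commit.
for (chosen, rejected,
     chosen_masks, rejected_masks,
     chosen_rewards, rejected_rewards) in iterate_orpo_batches(
        dataset, tokenizer, batch_size=4, max_seq_length=2048, train=True):
    out = orpo_loss(
        model,
        chosen=chosen,
        rejected=rejected,
        chosen_masks=chosen_masks,
        rejected_masks=rejected_masks,
        chosen_rewards=chosen_rewards,
        rejected_rewards=rejected_rewards,
        beta=0.1,
        reward_scaling=1.0,
    )
```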