Mirror of https://github.com/ml-explore/mlx-examples.git

commit 363bde634e (parent: ea0d11cd2f)

    fixes
@@ -110,7 +110,7 @@ Here's the equivalent ORPO documentation:
 
 ### ORPO Training
 
-Offline Reward Policy Optimization (ORPO) training allows you to fine-tune models using human preference data with pre-computed rewards. To use ORPO training, set the training mode to 'orpo':
+Odds Ratio Preference Optimization (ORPO) training allows you to fine-tune models using human preference data with pre-computed rewards. To use ORPO training, set the training mode to 'orpo':
 
 ```shell
 mlx_lm.lora \
@@ -70,7 +70,6 @@ CONFIG_DEFAULTS = {
     "is_reference_free": False,
     "delta": 50.0,
     "reference_model_path": None,
-    "train_bias_only": False,
     "reward_scaling": 1.0,
 }
 
@@ -181,7 +180,6 @@ def build_parser():
     parser.add_argument("--is-reference-free", action="store_true")
     parser.add_argument("--delta", type=float)
     parser.add_argument("--reference-model-path", type=str)
-    parser.add_argument("--train-bias-only", action="store_true")
     parser.add_argument("--reward-scaling", type=float, help="Scaling factor for offline rewards.")
     parser.add_argument("--seed", type=int, help="The PRNG seed.")
     return parser
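For orientation, a small hypothetical sketch of exercising the ORPO-related flags above without running a training job. It assumes `build_parser` is importable from `mlx_lm.lora` (not verified against the repository) and uses placeholder values.

```python
# Hypothetical sketch: parse the ORPO-related options shown in the hunk above.
# Assumes `build_parser` is importable from mlx_lm.lora; adjust the import if not.
from mlx_lm.lora import build_parser

args = build_parser().parse_args(
    [
        "--delta", "50.0",
        "--reward-scaling", "1.0",
        "--reference-model-path", "path/to/reference_model",  # placeholder path
        "--seed", "42",
    ]
)
print(args.delta, args.reward_scaling, args.is_reference_free)
```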
@@ -247,7 +245,6 @@ def train_model(
         is_reference_free=args.is_reference_free,
         delta=args.delta,
         reference_model_path=args.reference_model_path,
-        train_bias_only=args.train_bias_only,
     )
 
     if args.reference_model_path:
@@ -278,8 +275,6 @@ def train_model(
             grad_checkpoint=args.grad_checkpoint,
             beta=args.beta,
             reward_scaling=args.reward_scaling,
-            train_bias_only=args.train_bias_only,
-            seed=args.seed,
         )
 
         train_orpo(
@@ -8,8 +8,8 @@ from transformers import PreTrainedTokenizer
 class DPODataset:
     """
     A dataset for DPO (Direct Preference Optimization) training that handles
-    prompt-chosen-rejected triplets in the format:
-    {"prompt": ..., "chosen": ..., "rejected": ...}
+    prompt-chosen-rejected triplets with optional scores in the format:
+    {"prompt": ..., "chosen": ..., "rejected": ..., "score_chosen": ..., "score_rejected": ...}
     """
 
     def __init__(
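As a concrete illustration of the record format documented above (field values here are invented for the example):

```python
# Example training records in the format documented in the DPODataset docstring.
# "score_chosen"/"score_rejected" are optional; records without them fall back
# to a binary preference, as implemented in __init__ below.
records = [
    {
        "prompt": "What is the capital of France?",
        "chosen": "The capital of France is Paris.",
        "rejected": "I think it might be Lyon.",
        "score_chosen": 9.0,
        "score_rejected": 2.0,
    },
    {
        # No scores: the dataset defaults to a preference score of 1.0.
        "prompt": "Name a prime number.",
        "chosen": "2 is the smallest prime number.",
        "rejected": "9 is a prime number.",
    },
]
```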
@@ -19,31 +19,51 @@ class DPODataset:
         prompt_key: str = "prompt",
         chosen_key: str = "chosen",
         rejected_key: str = "rejected",
+        score_chosen_key: str = "score_chosen",
+        score_rejected_key: str = "score_rejected",
     ):
-        self._chosen_data = [
-            tokenizer.apply_chat_template(
-                [
-                    {"role": "user", "content": d[prompt_key]},
-                    {"role": "assistant", "content": d[chosen_key]},
-                ],
-            )
-            for d in data
-        ]
-
-        self._rejected_data = [
-            tokenizer.apply_chat_template(
-                [
-                    {"role": "user", "content": d[prompt_key]},
-                    {"role": "assistant", "content": d[rejected_key]},
-                ],
-            )
-            for d in data
-        ]
+        self._chosen_data = []
+        self._rejected_data = []
+        self._scores = []
+
+        for d in data:
+            # Process the text data
+            chosen_text = tokenizer.apply_chat_template(
+                [
+                    {"role": "user", "content": d[prompt_key]},
+                    {"role": "assistant", "content": d[chosen_key]},
+                ],
+            )
+            rejected_text = tokenizer.apply_chat_template(
+                [
+                    {"role": "user", "content": d[prompt_key]},
+                    {"role": "assistant", "content": d[rejected_key]},
+                ],
+            )
+            self._chosen_data.append(chosen_text)
+            self._rejected_data.append(rejected_text)
+
+            # Handle scores if they exist
+            if score_chosen_key in d and score_rejected_key in d:
+                chosen_score = float(d[score_chosen_key])
+                rejected_score = float(d[score_rejected_key])
+
+                # Normalize scores to [0, 1] range
+                score_diff = chosen_score - rejected_score
+                max_diff = max(abs(score_diff), 1.0)  # Avoid division by zero
+                normalized_score = (score_diff / max_diff + 1) / 2
+
+                self._scores.append(normalized_score)
+            else:
+                # Default to binary preference (1.0) if no scores provided
+                self._scores.append(1.0)
 
     def __getitem__(self, idx: int):
         return {
             "chosen": self._chosen_data[idx],
-            "rejected": self._rejected_data[idx]
+            "rejected": self._rejected_data[idx],
+            "preference_score": self._scores[idx]
         }
 
     def __len__(self):
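To make the normalization above concrete, here is a small standalone sketch that reproduces the same arithmetic on a few example score pairs (values invented for illustration):

```python
# Standalone sketch of the score normalization used in DPODataset.__init__ above.
def normalize_preference(chosen_score: float, rejected_score: float) -> float:
    score_diff = chosen_score - rejected_score
    max_diff = max(abs(score_diff), 1.0)  # Avoid division by zero
    return (score_diff / max_diff + 1) / 2

print(normalize_preference(9.0, 2.0))  # 1.0  -> any gap >= 1 saturates to full preference
print(normalize_preference(2.0, 9.0))  # 0.0  -> "rejected" is actually better
print(normalize_preference(5.0, 5.0))  # 0.5  -> no preference
print(normalize_preference(0.7, 0.4))  # ~0.65 -> sub-unit gaps interpolate
```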
@@ -45,12 +45,6 @@ class DPOTrainingArgs(TrainingArgs):
             "help": "Path to reference model weights. If None, uses the same model."
         }
     )
-    train_bias_only: bool = field(
-        default=False,
-        metadata={
-            "help": "Whether to train only bias terms in the model."
-        }
-    )
     seed: int = field(
         default=42,
         metadata={
@@ -30,19 +30,19 @@ def orpo_loss(
     rejected_masks: mx.array,
     chosen_rewards: mx.array,
     rejected_rewards: mx.array,
-    beta: float,
+    beta: float = 0.1,
     reward_scaling: float = 1.0,
 ):
     """
-    Calculate ORPO loss using pre-computed rewards.
+    Calculate ORPO loss using pre-computed rewards that incorporate preference scores.
     Args:
         model: Policy model
         chosen: Chosen sequence tokens
         rejected: Rejected sequence tokens
         chosen_masks: Attention masks for chosen sequences
         rejected_masks: Attention masks for rejected sequences
-        chosen_rewards: Pre-computed rewards for chosen sequences
-        rejected_rewards: Pre-computed rewards for rejected sequences
+        chosen_rewards: Rewards for chosen sequences (derived from preference scores)
+        rejected_rewards: Rewards for rejected sequences (derived from preference scores)
         beta: Temperature parameter
         reward_scaling: Scaling factor for rewards
     Returns:
@@ -65,7 +65,7 @@ def orpo_loss(
     chosen_rewards = chosen_rewards * reward_scaling
     rejected_rewards = rejected_rewards * reward_scaling
 
-    # ORPO uses the reward difference directly
+    # Calculate reward difference
    reward_diff = chosen_rewards - rejected_rewards
 
     # Calculate ORPO loss using logistic function
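The logistic-loss line itself falls outside this hunk, so the following is only a rough, hypothetical sketch of how a sigmoid-based penalty over the reward difference could look in MLX. It is not the repository's exact formula, and `policy_logratio` is a placeholder for whatever the surrounding code computes from the policy model's log-probabilities.

```python
import mlx.core as mx

def sketch_orpo_style_loss(policy_logratio: mx.array, reward_diff: mx.array, beta: float = 0.1) -> mx.array:
    # Illustrative only: a logistic (sigmoid) penalty on the combined preference signal.
    # This is NOT the exact loss in orpo_trainer.py; the real computation is outside this hunk.
    return mx.mean(-mx.log(mx.sigmoid(beta * (policy_logratio + reward_diff))))

# Made-up numbers, just to show shapes and sign conventions.
print(sketch_orpo_style_loss(mx.array([0.3, -0.1]), mx.array([1.0, 0.5])))
```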
@@ -140,7 +140,7 @@ def evaluate_orpo(
 
 def iterate_orpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
     """
-    Modified batch iterator for ORPO that includes pre-computed rewards.
+    Modified batch iterator for ORPO that includes preference scores.
     Works with pre-tokenized input data.
     """
     # Sort pairs by length of the chosen response
@@ -186,9 +186,14 @@ def iterate_orpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=F
         chosen_masks = np.zeros((batch_size // step, max_length_in_batch), np.float32)
         rejected_masks = np.zeros((batch_size // step, max_length_in_batch), np.float32)
 
-        # Always use binary rewards
-        chosen_rewards = np.ones((batch_size // step,), np.float32)
-        rejected_rewards = np.zeros((batch_size // step,), np.float32)
+        # Get preference scores and convert to rewards
+        preference_scores = np.array([x.get('preference_score', 1.0) for x in batch], np.float32)
+        # Convert preference scores to chosen/rejected rewards
+        # When preference_score is 1.0, chosen_reward=1.0, rejected_reward=0.0
+        # When preference_score is 0.0, chosen_reward=0.0, rejected_reward=1.0
+        # When preference_score is 0.5, both rewards are 0.5
+        chosen_rewards = preference_scores
+        rejected_rewards = 1.0 - preference_scores
 
         for j in range(batch_size // step):
             # Use pre-tokenized sequences directly
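A quick standalone check of the score-to-reward conversion above (array values are invented):

```python
import numpy as np

# Reproduces the preference-score-to-reward mapping from iterate_orpo_batches above.
preference_scores = np.array([1.0, 0.0, 0.5, 0.65], np.float32)
chosen_rewards = preference_scores
rejected_rewards = 1.0 - preference_scores

print(chosen_rewards)    # [1.   0.   0.5  0.65]
print(rejected_rewards)  # [0.   1.   0.5  0.35]
```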
@@ -200,9 +205,14 @@ def iterate_orpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=F
             rejected_arr[j, :rejected_length] = batch[j]['rejected'][:rejected_length]
             rejected_masks[j, :rejected_length] = 1.0
 
-        yield (mx.array(chosen_arr), mx.array(rejected_arr),
-               mx.array(chosen_masks), mx.array(rejected_masks),
-               mx.array(chosen_rewards), mx.array(rejected_rewards))
+        yield (
+            mx.array(chosen_arr),
+            mx.array(rejected_arr),
+            mx.array(chosen_masks),
+            mx.array(rejected_masks),
+            mx.array(chosen_rewards),
+            mx.array(rejected_rewards)
+        )
 
         if not train:
             break