Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-08-09 18:36:38 +08:00)
nits and quality of life improvements
This commit is contained in: parent 531c3345c6 · commit 86b315fdf9
@@ -19,7 +19,7 @@ LoRA (QLoRA).[^qlora] LoRA fine-tuning works with the following model families:
 - [Run](#Run)
 - [Fine-tune](#Fine-tune)
-- [DPO Training](#DPO Training)
+- [DPO-Training](#DPOTraining)
 - [Evaluate](#Evaluate)
 - [Generate](#Generate)
 - [Fuse](#Fuse)
@@ -105,6 +105,12 @@ For DPO training, the data should be in JSONL format with the following structure:
 {"prompt": "User prompt", "chosen": "Preferred response", "rejected": "Less preferred response"}
 ```
 
+If the prompt template accepts a system message, you can extend the dataset with an additional "system" field.
+
+```jsonl
+{"system": "You are a helpful assistant", "prompt": "User prompt", "chosen": "Preferred response", "rejected": "Less preferred response"}
+```
+
 ### Evaluate
 
 To compute test set perplexity use:
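Returning to the data format shown in the hunk above: a small standalone sketch (not part of the repo) of writing preference pairs in that JSONL layout, one JSON object per line. The file name `train.jsonl` and the example strings are illustrative only.

```python
import json

# Hypothetical preference pairs in the JSONL layout described above.
pairs = [
    {"prompt": "User prompt", "chosen": "Preferred response", "rejected": "Less preferred response"},
    {"system": "You are a helpful assistant", "prompt": "User prompt",
     "chosen": "Preferred response", "rejected": "Less preferred response"},
]

# JSONL convention: one JSON object per line.
with open("train.jsonl", "w") as f:
    for record in pairs:
        f.write(json.dumps(record) + "\n")
```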
@@ -242,8 +242,7 @@ def train_model(
         loss_type=args.dpo_loss_type,
         is_reference_free=args.is_reference_free,
         delta=args.delta,
         reference_model_path=args.reference_model_path,
         train_bias_only=args.train_bias_only,
-        reference_model_path=args.reference_model_path
     )
 
     if args.reference_model_path:
@@ -12,54 +12,25 @@ class DPODataset:
     {"system": ..., "prompt": ..., "chosen": ..., "rejected": ...}
     """
 
-    def __init__(
-        self,
-        data: List[Dict[str, str]],
-        tokenizer: PreTrainedTokenizer,
-        prompt_key: str = "prompt",
-        chosen_key: str = "chosen",
-        rejected_key: str = "rejected",
-        system_key: str = "system",
-    ):
+    def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer,
+                 prompt_key: str = "prompt", chosen_key: str = "chosen",
+                 rejected_key: str = "rejected", system_key: str = "system"):
         self._chosen_data = []
         self._rejected_data = []
         self._scores = []
 
         for d in data:
-            if system_key and system_key in d:
-                chosen = tokenizer.apply_chat_template(
-                    [
-                        {"role": "system", "content": d[system_key]},
-                        {"role": "user", "content": d[prompt_key]},
-                        {"role": "assistant", "content": d[chosen_key]},
-                    ]
-                )
-
-                rejected = tokenizer.apply_chat_template(
-                    [
-                        {"role": "system", "content": d[system_key]},
-                        {"role": "user", "content": d[prompt_key]},
-                        {"role": "assistant", "content": d[rejected_key]},
-                    ],
-                )
-            else:
-                chosen = tokenizer.apply_chat_template(
-                    [
-                        {"role": "user", "content": d[prompt_key]},
-                        {"role": "assistant", "content": d[chosen_key]},
-                    ]
-                )
-
-                rejected = tokenizer.apply_chat_template(
-                    [
-                        {"role": "user", "content": d[prompt_key]},
-                        {"role": "assistant", "content": d[rejected_key]},
-                    ],
-                )
-
-            self._chosen_data.append(chosen)
-            self._rejected_data.append(rejected)
+            messages = (
+                [{"role": "system", "content": d[system_key]}] if system_key and system_key in d else []
+            )
+            messages.append({"role": "user", "content": d[prompt_key]})
+
+            # Apply template once for each response type
+            base_messages = messages.copy()
+            chosen_messages = base_messages + [{"role": "assistant", "content": d[chosen_key]}]
+            rejected_messages = base_messages + [{"role": "assistant", "content": d[rejected_key]}]
+
+            self._chosen_data.append(tokenizer.apply_chat_template(chosen_messages))
+            self._rejected_data.append(tokenizer.apply_chat_template(rejected_messages))
 
     def __getitem__(self, idx: int):
         return {
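The refactored `__init__` above builds the shared prefix (optional system turn plus the user prompt) once and only appends the chosen or rejected answer before tokenizing. A minimal standalone sketch of that pattern, assuming any Hugging Face tokenizer whose chat template accepts a system role; the model name below is just an example, not something the repo prescribes.

```python
from transformers import AutoTokenizer

# Example model; substitute any chat model whose template supports a "system" role.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

d = {"system": "You are a helpful assistant", "prompt": "Hi!",
     "chosen": "Hello! How can I help?", "rejected": "Go away."}

# Shared prefix built once: optional system turn, then the user prompt.
messages = [{"role": "system", "content": d["system"]}] if "system" in d else []
messages.append({"role": "user", "content": d["prompt"]})

# Append each candidate answer and tokenize; apply_chat_template returns token ids by default.
chosen_tokens = tokenizer.apply_chat_template(messages + [{"role": "assistant", "content": d["chosen"]}])
rejected_tokens = tokenizer.apply_chat_template(messages + [{"role": "assistant", "content": d["rejected"]}])
print(len(chosen_tokens), len(rejected_tokens))
```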
@@ -46,18 +46,6 @@ class DPOTrainingArgs(TrainingArgs):
             "help": "Path to reference model weights. If None, uses the same model."
         }
     )
-    train_bias_only: bool = field(
-        default=False,
-        metadata={
-            "help": "Whether to train only bias terms in the model."
-        }
-    )
-    seed: int = field(
-        default=42,
-        metadata={
-            "help": "Random seed for reproducibility."
-        }
-    )
 
 
 def dpo_loss(
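The argument fields above (both the removed ones and the `reference_model_path` field that stays) use the standard dataclass pattern of a default plus a `help` string in `metadata`. A self-contained toy illustration of that pattern, with made-up field values rather than the repo's class:

```python
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class ToyArgs:
    # Made-up fields mirroring the default-plus-help-metadata pattern above.
    seed: int = field(default=42, metadata={"help": "Random seed for reproducibility."})
    reference_model_path: Optional[str] = field(
        default=None, metadata={"help": "Path to reference model weights. If None, uses the same model."}
    )

args = ToyArgs()
print(args.seed)                                              # 42
print(ToyArgs.__dataclass_fields__["seed"].metadata["help"])  # help text is introspectable
```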
@@ -72,22 +60,13 @@ def dpo_loss(
     loss_type: str = "sigmoid",
     is_reference_free: bool = False
 ):
-    """
-    Calculate loss for inputs.
-    Args:
-        inputs: Input tokens.
-        targets: Target tokens.
-        lengths: Lengths of inputs.
-    Returns:
-        Loss value.
-    """
     def make_predictions(model, x, mask):
         inputs = x[:, :-1]
         targets = x[:, 1:]
 
         logits = model(inputs)
         logits = logits.astype(mx.float32)
 
         return -nn.losses.cross_entropy(logits, targets) * mask[:, :-1]
 
     num_chosen_tokens = chosen_masks.sum(-1)
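For context, `make_predictions` above scores a sequence by shifting the tokens one position, taking per-token cross entropy against the next token, and masking padding. A toy sketch of that idea in MLX, with random logits standing in for a model (shapes and values are made up):

```python
import mlx.core as mx
import mlx.nn as nn

def sequence_logprob(logits, tokens, mask):
    # logits: (batch, seq, vocab); tokens, mask: (batch, seq)
    shifted_logits = logits[:, :-1, :]   # position t predicts token t+1
    targets = tokens[:, 1:]
    # cross_entropy with the default reduction="none" gives one value per token
    per_token = -nn.losses.cross_entropy(shifted_logits, targets) * mask[:, :-1]
    return per_token.sum(-1)             # summed log-likelihood per sequence

batch, seq, vocab = 2, 5, 16
tokens = mx.array([[1, 4, 2, 7, 0], [3, 3, 5, 0, 0]])
mask = mx.array([[1, 1, 1, 1, 0], [1, 1, 1, 0, 0]], dtype=mx.float32)
logits = mx.random.normal((batch, seq, vocab))
print(sequence_logprob(logits, tokens, mask))
```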
@@ -121,7 +100,7 @@ def dpo_loss(
 
     logits = (policy_chosen_score - policy_rejected_score) - (reference_chosen_score - reference_rejected_score)
 
-    if loss_type == "sigmoid":
+    if loss_type == "sigmoid":  # from the original DPO paper
         losses = -nn.log_sigmoid(beta * logits)
     elif loss_type == "hinge":
         losses = nn.relu(1 - beta * logits)
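A quick numeric sketch of the two branches above, using made-up scalar scores rather than anything computed by the repo: the sigmoid form is the original DPO objective, and the hinge form matches the SLiC-style variant.

```python
import mlx.core as mx
import mlx.nn as nn

beta = 0.1
policy_chosen_score, policy_rejected_score = mx.array(-12.0), mx.array(-15.0)
reference_chosen_score, reference_rejected_score = mx.array(-13.0), mx.array(-14.5)

# Implicit reward margin of chosen over rejected, measured relative to the reference model.
logits = (policy_chosen_score - policy_rejected_score) - (reference_chosen_score - reference_rejected_score)

print(-nn.log_sigmoid(beta * logits).item())  # sigmoid (original DPO) loss
print(nn.relu(1 - beta * logits).item())      # hinge loss
```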
@@ -144,70 +123,46 @@ def dpo_loss(
     return loss, reward, num_tokens
 
 
-def iterate_dpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
-    """
-    Modified iterate_batches for DPO training that handles chosen and rejected samples.
-    """
-    # Sort pairs by length of the chosen response
+def iterate_dpo_batches(dataset, batch_size, max_seq_length, train=False):
     idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]['chosen']))
     if len(dataset) < batch_size:
         raise ValueError(
             f"Dataset must have at least batch_size={batch_size}"
             f" examples but only has {len(dataset)}."
         )
 
     step = mx.distributed.init().size()
     if batch_size % step != 0:
-        raise ValueError("The batch size must be divisible by the number of workers")
+        raise ValueError("Batch size must be divisible by workers")
 
-    batch_idx = [
-        idx[i : i + batch_size : step]
-        for i in range(0, len(idx) - batch_size + 1, batch_size)
-    ]
+    batch_idx = [idx[i:i+batch_size:step] for i in range(0, len(idx)-batch_size+1, batch_size)]
 
     while True:
         indices = np.random.permutation(len(batch_idx)) if train else range(len(batch_idx))
         for i in indices:
             batch = [dataset[j] for j in batch_idx[i]]
 
-            # Get lengths for chosen and rejected sequences
+            # Get and process lengths
            chosen_lengths = [len(x['chosen']) for x in batch]
             rejected_lengths = [len(x['rejected']) for x in batch]
-            max_length = max(max(chosen_lengths), max(rejected_lengths))
+            max_length = min(max(max(chosen_lengths), max(rejected_lengths)), max_seq_length)
 
             # Dynamic padding based on batch content
             max_length_in_batch = max_length
 
-            if max_length > max_seq_length:
-                print(
-                    f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "
-                    f"The longest sequence {max_length} will be truncated to {max_seq_length}."
-                )
-
             # Pad to nearest multiple of 8
             pad_to = 8
             max_length_in_batch = pad_to * ((max_length + pad_to - 1) // pad_to)
             max_length_in_batch = min(max_length_in_batch, max_seq_length)
 
             # Create arrays for chosen and rejected sequences
             chosen_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
             rejected_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
 
             # Create attention masks
             chosen_masks = np.zeros((batch_size // step, max_length_in_batch), np.float32)
             rejected_masks = np.zeros((batch_size // step, max_length_in_batch), np.float32)
 
             for j in range(batch_size // step):
-                # Process chosen sequence
                 chosen_length = min(chosen_lengths[j], max_seq_length)
-                chosen_arr[j, :chosen_length] = batch[j]['chosen'][:chosen_length]
-                chosen_masks[j, :chosen_length] = 1.0
-
-                # Process rejected sequence
                 rejected_length = min(rejected_lengths[j], max_seq_length)
+
+                chosen_arr[j, :chosen_length] = batch[j]['chosen'][:chosen_length]
                 rejected_arr[j, :rejected_length] = batch[j]['rejected'][:rejected_length]
+
+                chosen_masks[j, :chosen_length] = 1.0
                 rejected_masks[j, :rejected_length] = 1.0
 
-            yield (mx.array(chosen_arr), mx.array(rejected_arr),
-                   mx.array(chosen_masks), mx.array(rejected_masks))
+            yield mx.array(chosen_arr), mx.array(rejected_arr), mx.array(chosen_masks), mx.array(rejected_masks)
 
             if not train:
                 break
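A plain-Python toy sketch (hypothetical lengths, no MLX needed) of the batching scheme above: sort example indices by the length of the chosen response, stride each batch-sized window across distributed workers, and round the padded length up to a multiple of 8 while capping it at `max_seq_length`.

```python
lengths = [5, 12, 7, 9, 20, 3, 11, 8]   # stand-in for len(example["chosen"])
batch_size, workers, max_seq_length = 4, 2, 16

# Sort indices by length so each batch holds similarly sized sequences.
idx = sorted(range(len(lengths)), key=lambda i: lengths[i])

# Each worker takes every `workers`-th index from a batch-sized window.
batch_idx = [idx[i:i + batch_size:workers] for i in range(0, len(idx) - batch_size + 1, batch_size)]

pad_to = 8
for batch in batch_idx:
    max_len = min(max(lengths[i] for i in batch), max_seq_length)
    padded = min(pad_to * ((max_len + pad_to - 1) // pad_to), max_seq_length)  # round up, then cap
    print(batch, max_len, padded)
```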
@@ -225,9 +180,6 @@ def evaluate_dpo(
     loss_fn: callable = dpo_loss,
     loss_type="sigmoid",
 ):
-    """
-    Modified evaluate function for DPO training.
-    """
     all_losses = 0
     all_rewards = mx.zeros((2,))  # [chosen_reward, rejected_reward]
     ntokens = 0
@@ -238,7 +190,6 @@
         index_iterator,
         iterate_dpo_batches(
             dataset=dataset,
-            tokenizer=tokenizer,
             batch_size=batch_size,
             max_seq_length=max_seq_length,
         ),
@@ -279,9 +230,6 @@ def train_dpo(
     training_callback: TrainingCallback = None,
     loss_type="sigmoid",
 ):
-    """
-    Modified training function for DPO.
-    """
     print(f"Starting DPO training..., iters: {args.iters}")
     world = mx.distributed.init()
     world_size = world.size()
@@ -345,7 +293,6 @@ def train_dpo(
         range(1, args.iters + 1),
         iterate_dpo_batches(
             dataset=train_dataset,
-            tokenizer=tokenizer,
             batch_size=args.batch_size,
             max_seq_length=args.max_seq_length,
             train=True,
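One structural detail worth noting in the hunk above: `iterate_dpo_batches` yields forever when `train=True`, and the training loop stops by zipping it with a finite counter. A toy sketch of that control flow, with a placeholder generator standing in for the real batch iterator:

```python
def endless_batches():
    # Placeholder for iterate_dpo_batches(..., train=True): yields indefinitely.
    step = 0
    while True:
        step += 1
        yield f"batch-{step}"

iters = 3
for it, batch in zip(range(1, iters + 1), endless_batches()):
    print(it, batch)   # zip stops after `iters` steps even though the generator never does
```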