mirror of https://github.com/ml-explore/mlx-examples.git

Commit 86b315fdf9 (parent 531c3345c6): nits and quality of life improvements
@@ -19,7 +19,7 @@ LoRA (QLoRA).[^qlora] LoRA fine-tuning works with the following model families:
 - [Run](#Run)
 - [Fine-tune](#Fine-tune)
-- [DPO Training](#DPO Training)
+- [DPO-Training](#DPOTraining)
 - [Evaluate](#Evaluate)
 - [Generate](#Generate)
 - [Fuse](#Fuse)
@@ -105,6 +105,12 @@ For DPO training, the data should be in JSONL format with the following structure:
 {"prompt": "User prompt", "chosen": "Preferred response", "rejected": "Less preferred response"}
 ```
 
+If the prompt template accepts a system message, you can extend the dataset with an additional "system" field.
+
+```jsonl
+{"system": "You are a helpful assistant", "prompt": "User prompt", "chosen": "Preferred response", "rejected": "Less preferred response"}
+```
+
 ### Evaluate
 
 To compute test set perplexity use:
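An aside on the data format shown in this hunk: the JSONL records are easy to sanity-check before launching a run. Below is a minimal sketch, not code from this commit; the file name `train.jsonl` and the validation logic are assumptions, and only the field names come from the README example.

```python
import json

path = "train.jsonl"  # hypothetical path to the DPO data
required = {"prompt", "chosen", "rejected"}

with open(path) as f:
    for lineno, line in enumerate(f, 1):
        record = json.loads(line)
        missing = required - record.keys()
        if missing:
            raise ValueError(f"line {lineno}: missing fields {sorted(missing)}")
        # Optional field, only used when the prompt template supports a system message.
        system = record.get("system", "")
```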
@@ -242,8 +242,7 @@ def train_model(
         loss_type=args.dpo_loss_type,
         is_reference_free=args.is_reference_free,
         delta=args.delta,
-        reference_model_path=args.reference_model_path,
-        train_bias_only=args.train_bias_only,
+        reference_model_path=args.reference_model_path
     )
 
     if args.reference_model_path:
@@ -12,54 +12,25 @@ class DPODataset:
     {"system": ..., "prompt": ..., "chosen": ..., "rejected": ...}
     """
 
-    def __init__(
-        self,
-        data: List[Dict[str, str]],
-        tokenizer: PreTrainedTokenizer,
-        prompt_key: str = "prompt",
-        chosen_key: str = "chosen",
-        rejected_key: str = "rejected",
-        system_key: str = "system",
-    ):
+    def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer,
+                 prompt_key: str = "prompt", chosen_key: str = "chosen",
+                 rejected_key: str = "rejected", system_key: str = "system"):
         self._chosen_data = []
         self._rejected_data = []
-        self._scores = []
 
         for d in data:
-            if system_key and system_key in d:
-                chosen = tokenizer.apply_chat_template(
-                    [
-                        {"role": "system", "content": d[system_key]},
-                        {"role": "user", "content": d[prompt_key]},
-                        {"role": "assistant", "content": d[chosen_key]},
-                    ]
-                )
-
-                rejected = tokenizer.apply_chat_template(
-                    [
-                        {"role": "system", "content": d[system_key]},
-                        {"role": "user", "content": d[prompt_key]},
-                        {"role": "assistant", "content": d[rejected_key]},
-                    ],
-                )
-            else:
-                chosen = tokenizer.apply_chat_template(
-                    [
-                        {"role": "user", "content": d[prompt_key]},
-                        {"role": "assistant", "content": d[chosen_key]},
-                    ]
-                )
-
-                rejected = tokenizer.apply_chat_template(
-                    [
-                        {"role": "user", "content": d[prompt_key]},
-                        {"role": "assistant", "content": d[rejected_key]},
-                    ],
-                )
-
-            self._chosen_data.append(chosen)
-            self._rejected_data.append(rejected)
+            messages = (
+                [{"role": "system", "content": d[system_key]}] if system_key and system_key in d else []
+            )
+            messages.append({"role": "user", "content": d[prompt_key]})
+
+            # Apply template once for each response type
+            base_messages = messages.copy()
+            chosen_messages = base_messages + [{"role": "assistant", "content": d[chosen_key]}]
+            rejected_messages = base_messages + [{"role": "assistant", "content": d[rejected_key]}]
+
+            self._chosen_data.append(tokenizer.apply_chat_template(chosen_messages))
+            self._rejected_data.append(tokenizer.apply_chat_template(rejected_messages))
 
     def __getitem__(self, idx: int):
         return {
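To make the refactor in this hunk easier to follow, here is a standalone sketch of how one preference record becomes chosen/rejected message lists before tokenization. The record values are invented, and the final tokenization step is only indicated in a comment because it depends on the tokenizer in use.

```python
# One invented preference record with an optional system message.
d = {
    "system": "You are a helpful assistant",
    "prompt": "User prompt",
    "chosen": "Preferred response",
    "rejected": "Less preferred response",
}

# Shared prefix: optional system message plus the user prompt.
messages = [{"role": "system", "content": d["system"]}] if "system" in d else []
messages.append({"role": "user", "content": d["prompt"]})

# Each response type gets its own copy with the assistant turn appended.
chosen_messages = messages + [{"role": "assistant", "content": d["chosen"]}]
rejected_messages = messages + [{"role": "assistant", "content": d["rejected"]}]

print(chosen_messages)
print(rejected_messages)

# In the dataset class, each list is then passed to
# tokenizer.apply_chat_template(...) to produce token ids.
```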
@@ -46,18 +46,6 @@ class DPOTrainingArgs(TrainingArgs):
             "help": "Path to reference model weights. If None, uses the same model."
         }
     )
-    train_bias_only: bool = field(
-        default=False,
-        metadata={
-            "help": "Whether to train only bias terms in the model."
-        }
-    )
-    seed: int = field(
-        default=42,
-        metadata={
-            "help": "Random seed for reproducibility."
-        }
-    )
 
 
 def dpo_loss(
@@ -72,15 +60,6 @@ def dpo_loss(
     loss_type: str = "sigmoid",
     is_reference_free: bool = False
 ):
-    """
-    Calculate loss for inputs.
-    Args:
-        inputs: Input tokens.
-        targets: Target tokens.
-        lengths: Lengths of inputs.
-    Returns:
-        Loss value.
-    """
     def make_predictions(model, x, mask):
         inputs = x[:, :-1]
         targets = x[:, 1:]
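A brief note on the `make_predictions` context lines above: the sequence is shifted by one position so the model predicts each next token. A toy illustration, assuming `mlx` is installed and using made-up token ids:

```python
import mlx.core as mx

x = mx.array([[7, 3, 9, 2]])  # made-up token ids for one sequence

inputs = x[:, :-1]   # [[7, 3, 9]]  -> fed to the model
targets = x[:, 1:]   # [[3, 9, 2]]  -> next-token labels aligned with the logits

print(inputs, targets)
```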
@@ -121,7 +100,7 @@ def dpo_loss(
 
     logits = (policy_chosen_score - policy_rejected_score) - (reference_chosen_score - reference_rejected_score)
 
-    if loss_type == "sigmoid":
+    if loss_type == "sigmoid":  # From the og paper
         losses = -nn.log_sigmoid(beta * logits)
     elif loss_type == "hinge":
         losses = nn.relu(1 - beta * logits)
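For context on the branch annotated in this hunk: the sigmoid case is the DPO objective from the original paper, -log sigmoid(beta * logits), where `logits` is the policy's chosen-minus-rejected margin minus the reference model's margin. The scores below are invented purely to show the shapes involved; this is a sketch, not the repository's code path.

```python
import mlx.core as mx
import mlx.nn as nn

beta = 0.1

# Invented per-example sequence log-likelihood scores.
policy_chosen_score = mx.array([-12.0])
policy_rejected_score = mx.array([-15.0])
reference_chosen_score = mx.array([-13.0])
reference_rejected_score = mx.array([-14.5])

logits = (policy_chosen_score - policy_rejected_score) - (
    reference_chosen_score - reference_rejected_score
)

sigmoid_loss = -nn.log_sigmoid(beta * logits)  # standard DPO loss
hinge_loss = nn.relu(1 - beta * logits)        # hinge variant

print(sigmoid_loss.item(), hinge_loss.item())
```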
@@ -144,69 +123,45 @@ def dpo_loss(
     return loss, reward, num_tokens
 
 
-def iterate_dpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
-    """
-    Modified iterate_batches for DPO training that handles chosen and rejected samples.
-    """
-    # Sort pairs by length of the chosen response
+def iterate_dpo_batches(dataset, batch_size, max_seq_length, train=False):
     idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]['chosen']))
-    if len(dataset) < batch_size:
-        raise ValueError(
-            f"Dataset must have at least batch_size={batch_size}"
-            f" examples but only has {len(dataset)}."
-        )
 
     step = mx.distributed.init().size()
     if batch_size % step != 0:
-        raise ValueError("The batch size must be divisible by the number of workers")
+        raise ValueError("Batch size must be divisible by workers")
 
-    batch_idx = [
-        idx[i : i + batch_size : step]
-        for i in range(0, len(idx) - batch_size + 1, batch_size)
-    ]
+    batch_idx = [idx[i:i+batch_size:step] for i in range(0, len(idx)-batch_size+1, batch_size)]
 
     while True:
         indices = np.random.permutation(len(batch_idx)) if train else range(len(batch_idx))
         for i in indices:
             batch = [dataset[j] for j in batch_idx[i]]
 
-            # Get lengths for chosen and rejected sequences
+            # Get and process lengths
             chosen_lengths = [len(x['chosen']) for x in batch]
             rejected_lengths = [len(x['rejected']) for x in batch]
-            max_length = max(max(chosen_lengths), max(rejected_lengths))
+            max_length = min(max(max(chosen_lengths), max(rejected_lengths)), max_seq_length)
 
-            if max_length > max_seq_length:
-                print(
-                    f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "
-                    f"The longest sequence {max_length} will be truncated to {max_seq_length}."
-                )
-
-            # Pad to nearest multiple of 8
-            pad_to = 8
-            max_length_in_batch = pad_to * ((max_length + pad_to - 1) // pad_to)
-            max_length_in_batch = min(max_length_in_batch, max_seq_length)
+            # Dynamic padding based on batch content
+            max_length_in_batch = max_length
 
-            # Create arrays for chosen and rejected sequences
             chosen_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
             rejected_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
 
-            # Create attention masks
             chosen_masks = np.zeros((batch_size // step, max_length_in_batch), np.float32)
             rejected_masks = np.zeros((batch_size // step, max_length_in_batch), np.float32)
 
             for j in range(batch_size // step):
-                # Process chosen sequence
                 chosen_length = min(chosen_lengths[j], max_seq_length)
-                chosen_arr[j, :chosen_length] = batch[j]['chosen'][:chosen_length]
-                chosen_masks[j, :chosen_length] = 1.0
-
-                # Process rejected sequence
                 rejected_length = min(rejected_lengths[j], max_seq_length)
+
+                chosen_arr[j, :chosen_length] = batch[j]['chosen'][:chosen_length]
                 rejected_arr[j, :rejected_length] = batch[j]['rejected'][:rejected_length]
+
+                chosen_masks[j, :chosen_length] = 1.0
                 rejected_masks[j, :rejected_length] = 1.0
 
-            yield (mx.array(chosen_arr), mx.array(rejected_arr),
-                   mx.array(chosen_masks), mx.array(rejected_masks))
+            yield mx.array(chosen_arr), mx.array(rejected_arr), mx.array(chosen_masks), mx.array(rejected_masks)
 
         if not train:
            break
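To illustrate what the batching code above produces, here is a tiny NumPy-only sketch of the dynamic padding and masking for the chosen side; the token lists are invented and the single-worker case (`step == 1`) is assumed.

```python
import numpy as np

max_seq_length = 8
batch = [
    {"chosen": [1, 2, 3, 4, 5], "rejected": [1, 2, 3]},
    {"chosen": [1, 2], "rejected": [1, 2, 3, 4]},
]

chosen_lengths = [len(x["chosen"]) for x in batch]
rejected_lengths = [len(x["rejected"]) for x in batch]
max_length_in_batch = min(max(max(chosen_lengths), max(rejected_lengths)), max_seq_length)

chosen_arr = np.zeros((len(batch), max_length_in_batch), np.int32)
chosen_masks = np.zeros((len(batch), max_length_in_batch), np.float32)
for j, x in enumerate(batch):
    n = min(chosen_lengths[j], max_seq_length)
    chosen_arr[j, :n] = x["chosen"][:n]
    chosen_masks[j, :n] = 1.0

print(chosen_arr)    # zero-padded token ids
print(chosen_masks)  # 1.0 over real tokens, 0.0 over padding
```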
@@ -225,9 +180,6 @@ def evaluate_dpo(
     loss_fn: callable = dpo_loss,
     loss_type="sigmoid",
 ):
-    """
-    Modified evaluate function for DPO training.
-    """
     all_losses = 0
     all_rewards = mx.zeros((2,))  # [chosen_reward, rejected_reward]
     ntokens = 0
@@ -238,7 +190,6 @@ def evaluate_dpo(
         index_iterator,
         iterate_dpo_batches(
             dataset=dataset,
-            tokenizer=tokenizer,
             batch_size=batch_size,
             max_seq_length=max_seq_length,
         ),
@@ -279,9 +230,6 @@ def train_dpo(
     training_callback: TrainingCallback = None,
     loss_type="sigmoid",
 ):
-    """
-    Modified training function for DPO.
-    """
     print(f"Starting DPO training..., iters: {args.iters}")
     world = mx.distributed.init()
     world_size = world.size()
@@ -345,7 +293,6 @@ def train_dpo(
         range(1, args.iters + 1),
         iterate_dpo_batches(
             dataset=train_dataset,
-            tokenizer=tokenizer,
             batch_size=args.batch_size,
             max_seq_length=args.max_seq_length,
             train=True,