nits and quality of life improvements

Goekdeniz-Guelmez 2025-01-24 22:40:27 +01:00
parent 531c3345c6
commit 86b315fdf9
4 changed files with 43 additions and 120 deletions

@@ -19,7 +19,7 @@ LoRA (QLoRA).[^qlora] LoRA fine-tuning works with the following model families:
- [Run](#Run)
- [Fine-tune](#Fine-tune)
- [DPO Training](#DPO Training)
- [DPO-Training](#DPOTraining)
- [Evaluate](#Evaluate)
- [Generate](#Generate)
- [Fuse](#Fuse)
@@ -105,6 +105,12 @@ For DPO training, the data should be in JSONL format with the following structure:
{"prompt": "User prompt", "chosen": "Preferred response", "rejected": "Less preferred response"}
```
If the prompt template accepts a system message, you can extend the dataset with an additional "system" field.
```jsonl
{"system": "You are a helpfull assistant", "prompt": "User prompt", "chosen": "Preferred response", "rejected": "Less preferred response"}
```
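Each line is a standalone JSON object, so the file can be loaded with the standard library. A minimal sketch, assuming a hypothetical `train.jsonl` path and the field names shown above:
```python
import json

# Minimal sketch: load a DPO-style JSONL file and check the expected fields.
# "train.jsonl" is a hypothetical path; the "system" field is optional.
with open("train.jsonl") as f:
    records = [json.loads(line) for line in f if line.strip()]

for r in records:
    assert {"prompt", "chosen", "rejected"} <= r.keys()
```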
### Evaluate
To compute test set perplexity use:

@@ -242,8 +242,7 @@ def train_model(
loss_type=args.dpo_loss_type,
is_reference_free=args.is_reference_free,
delta=args.delta,
reference_model_path=args.reference_model_path,
train_bias_only=args.train_bias_only,
reference_model_path=args.reference_model_path
)
if args.reference_model_path:

@@ -12,54 +12,25 @@ class DPODataset:
{"system": ..., "prompt": ..., "chosen": ..., "rejected": ...}
"""
def __init__(
self,
data: List[Dict[str, str]],
tokenizer: PreTrainedTokenizer,
prompt_key: str = "prompt",
chosen_key: str = "chosen",
rejected_key: str = "rejected",
system_key: str = "system",
):
def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer,
prompt_key: str = "prompt", chosen_key: str = "chosen",
rejected_key: str = "rejected", system_key: str = "system"):
self._chosen_data = []
self._rejected_data = []
self._scores = []
for d in data:
if system_key and system_key in d:
chosen = tokenizer.apply_chat_template(
[
{"role": "system", "content": d[system_key]},
{"role": "user", "content": d[prompt_key]},
{"role": "assistant", "content": d[chosen_key]},
]
)
messages = (
[{"role": "system", "content": d[system_key]}] if system_key and system_key in d else []
)
messages.append({"role": "user", "content": d[prompt_key]})
rejected = tokenizer.apply_chat_template(
[
{"role": "system", "content": d[system_key]},
{"role": "user", "content": d[prompt_key]},
{"role": "assistant", "content": d[rejected_key]},
],
)
else:
chosen = tokenizer.apply_chat_template(
[
{"role": "user", "content": d[prompt_key]},
{"role": "assistant", "content": d[chosen_key]},
]
)
rejected = tokenizer.apply_chat_template(
[
{"role": "user", "content": d[prompt_key]},
{"role": "assistant", "content": d[rejected_key]},
],
)
self._chosen_data.append(chosen)
self._rejected_data.append(rejected)
# Apply template once for each response type
base_messages = messages.copy()
chosen_messages = base_messages + [{"role": "assistant", "content": d[chosen_key]}]
rejected_messages = base_messages + [{"role": "assistant", "content": d[rejected_key]}]
self._chosen_data.append(tokenizer.apply_chat_template(chosen_messages))
self._rejected_data.append(tokenizer.apply_chat_template(rejected_messages))
def __getitem__(self, idx: int):
return {

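Outside the diff view, the refactored construction reduces to a small helper like the following (a sketch only; `build_pair` is a hypothetical name, and the tokenizer is assumed to be a `PreTrainedTokenizer` with a chat template, as in the class above):
```python
from typing import Dict, List, Tuple

def build_pair(tokenizer, d: Dict[str, str],
               system_key: str = "system", prompt_key: str = "prompt",
               chosen_key: str = "chosen", rejected_key: str = "rejected") -> Tuple[List[int], List[int]]:
    # Shared prefix: optional system message followed by the user prompt.
    messages = ([{"role": "system", "content": d[system_key]}]
                if system_key and system_key in d else [])
    messages.append({"role": "user", "content": d[prompt_key]})
    # Apply the chat template once per response type.
    chosen = tokenizer.apply_chat_template(
        messages + [{"role": "assistant", "content": d[chosen_key]}])
    rejected = tokenizer.apply_chat_template(
        messages + [{"role": "assistant", "content": d[rejected_key]}])
    return chosen, rejected
```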
@@ -46,18 +46,6 @@ class DPOTrainingArgs(TrainingArgs):
"help": "Path to reference model weights. If None, uses the same model."
}
)
train_bias_only: bool = field(
default=False,
metadata={
"help": "Whether to train only bias terms in the model."
}
)
seed: int = field(
default=42,
metadata={
"help": "Random seed for reproducibility."
}
)
def dpo_loss(
@@ -72,22 +60,13 @@ def dpo_loss(
loss_type: str = "sigmoid",
is_reference_free: bool = False
):
"""
Calculate loss for inputs.
Args:
inputs: Input tokens.
targets: Target tokens.
lengths: Lengths of inputs.
Returns:
Loss value.
"""
def make_predictions(model, x, mask):
inputs = x[:, :-1]
targets = x[:, 1:]
logits = model(inputs)
logits = logits.astype(mx.float32)
return -nn.losses.cross_entropy(logits, targets) * mask[:, :-1]
num_chosen_tokens = chosen_masks.sum(-1)
@@ -121,7 +100,7 @@ def dpo_loss(
logits = (policy_chosen_score - policy_rejected_score) - (reference_chosen_score - reference_rejected_score)
if loss_type == "sigmoid":
if loss_type == "sigmoid": # From the og paper
losses = -nn.log_sigmoid(beta * logits)
elif loss_type == "hinge":
losses = nn.relu(1 - beta * logits)
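For reference, the `sigmoid` branch is the objective from the original DPO paper and the `hinge` branch its margin variant, both applied to the policy-minus-reference log-ratio difference computed above. A minimal numeric sketch with made-up `beta` and `logits` values:
```python
import mlx.core as mx
import mlx.nn as nn

beta = 0.1                            # hypothetical value; supplied via the training arguments
logits = mx.array([0.5, -0.2, 1.3])   # example per-pair log-ratio differences
sigmoid_losses = -nn.log_sigmoid(beta * logits)  # DPO loss per pair
hinge_losses = nn.relu(1 - beta * logits)        # hinge variant shown above
```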
@@ -144,70 +123,46 @@ def dpo_loss(
return loss, reward, num_tokens
def iterate_dpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
"""
Modified iterate_batches for DPO training that handles chosen and rejected samples.
"""
# Sort pairs by length of the chosen response
def iterate_dpo_batches(dataset, batch_size, max_seq_length, train=False):
idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]['chosen']))
if len(dataset) < batch_size:
raise ValueError(
f"Dataset must have at least batch_size={batch_size}"
f" examples but only has {len(dataset)}."
)
step = mx.distributed.init().size()
if batch_size % step != 0:
raise ValueError("The batch size must be divisible by the number of workers")
batch_idx = [
idx[i : i + batch_size : step]
for i in range(0, len(idx) - batch_size + 1, batch_size)
]
raise ValueError("Batch size must be divisible by workers")
batch_idx = [idx[i:i+batch_size:step] for i in range(0, len(idx)-batch_size+1, batch_size)]
while True:
indices = np.random.permutation(len(batch_idx)) if train else range(len(batch_idx))
for i in indices:
batch = [dataset[j] for j in batch_idx[i]]
# Get lengths for chosen and rejected sequences
# Get and process lengths
chosen_lengths = [len(x['chosen']) for x in batch]
rejected_lengths = [len(x['rejected']) for x in batch]
max_length = max(max(chosen_lengths), max(rejected_lengths))
max_length = min(max(max(chosen_lengths), max(rejected_lengths)), max_seq_length)
# Dynamic padding based on batch content
max_length_in_batch = max_length
if max_length > max_seq_length:
print(
f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "
f"The longest sequence {max_length} will be truncated to {max_seq_length}."
)
# Pad to nearest multiple of 8
pad_to = 8
max_length_in_batch = pad_to * ((max_length + pad_to - 1) // pad_to)
max_length_in_batch = min(max_length_in_batch, max_seq_length)
# Create arrays for chosen and rejected sequences
chosen_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
rejected_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
# Create attention masks
chosen_masks = np.zeros((batch_size // step, max_length_in_batch), np.float32)
rejected_masks = np.zeros((batch_size // step, max_length_in_batch), np.float32)
for j in range(batch_size // step):
# Process chosen sequence
chosen_length = min(chosen_lengths[j], max_seq_length)
chosen_arr[j, :chosen_length] = batch[j]['chosen'][:chosen_length]
chosen_masks[j, :chosen_length] = 1.0
# Process rejected sequence
rejected_length = min(rejected_lengths[j], max_seq_length)
chosen_arr[j, :chosen_length] = batch[j]['chosen'][:chosen_length]
rejected_arr[j, :rejected_length] = batch[j]['rejected'][:rejected_length]
chosen_masks[j, :chosen_length] = 1.0
rejected_masks[j, :rejected_length] = 1.0
yield (mx.array(chosen_arr), mx.array(rejected_arr),
mx.array(chosen_masks), mx.array(rejected_masks))
yield mx.array(chosen_arr), mx.array(rejected_arr), mx.array(chosen_masks), mx.array(rejected_masks)
if not train:
break
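A hypothetical usage sketch of the simplified generator (the `batch_size` and `max_seq_length` values are illustrative; with `train=False` it yields each length-sorted batch once, which is how evaluation consumes it below):
```python
# dpo_dataset is a DPODataset as defined above; items are dicts with "chosen"/"rejected" token lists.
batches = iterate_dpo_batches(dataset=dpo_dataset, batch_size=4, max_seq_length=2048)
# Each yielded array has shape (batch_size // world_size, padded_len), e.g. (4, padded_len) on one worker.
chosen, rejected, chosen_masks, rejected_masks = next(batches)
```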
@@ -225,9 +180,6 @@ def evaluate_dpo(
loss_fn: callable = dpo_loss,
loss_type="sigmoid",
):
"""
Modified evaluate function for DPO training.
"""
all_losses = 0
all_rewards = mx.zeros((2,)) # [chosen_reward, rejected_reward]
ntokens = 0
@@ -238,7 +190,6 @@ def evaluate_dpo(
index_iterator,
iterate_dpo_batches(
dataset=dataset,
tokenizer=tokenizer,
batch_size=batch_size,
max_seq_length=max_seq_length,
),
@@ -279,9 +230,6 @@ def train_dpo(
training_callback: TrainingCallback = None,
loss_type="sigmoid",
):
"""
Modified training function for DPO.
"""
print(f"Starting DPO training..., iters: {args.iters}")
world = mx.distributed.init()
world_size = world.size()
@@ -345,7 +293,6 @@ def train_dpo(
range(1, args.iters + 1),
iterate_dpo_batches(
dataset=train_dataset,
tokenizer=tokenizer,
batch_size=args.batch_size,
max_seq_length=args.max_seq_length,
train=True,