From 86b315fdf94f729349b2e740387211e69755d2fb Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Fri, 24 Jan 2025 22:40:27 +0100
Subject: [PATCH] nits and quality of life improvements

---
 llms/mlx_lm/LORA.md              |  8 ++-
 llms/mlx_lm/lora.py              |  3 +-
 llms/mlx_lm/tuner/datasets.py    | 57 +++++--------
 llms/mlx_lm/tuner/dpo_trainer.py | 95 +++++++-------------------
 4 files changed, 43 insertions(+), 120 deletions(-)

diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md
index 3ae78a01..6dd8197d 100644
--- a/llms/mlx_lm/LORA.md
+++ b/llms/mlx_lm/LORA.md
@@ -19,7 +19,7 @@ LoRA (QLoRA).[^qlora] LoRA fine-tuning works with the following model families:
 
 - [Run](#Run)
   - [Fine-tune](#Fine-tune)
-  - [DPO Training](#DPO Training)
+  - [DPO Training](#DPO-Training)
   - [Evaluate](#Evaluate)
   - [Generate](#Generate)
 - [Fuse](#Fuse)
@@ -105,6 +105,12 @@ For DPO training, the data should be in JSONL format with the following structur
 {"prompt": "User prompt", "chosen": "Preferred response", "rejected": "Less preferred response"}
 ```
 
+If the prompt template accepts a system message, you can extend the dataset with an additional "system" field:
+
+```jsonl
+{"system": "You are a helpful assistant", "prompt": "User prompt", "chosen": "Preferred response", "rejected": "Less preferred response"}
+```
+
 ### Evaluate
 
 To compute test set perplexity use:
diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index 2e33bddc..dcf94bad 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -242,8 +242,7 @@ def train_model(
             loss_type=args.dpo_loss_type,
             is_reference_free=args.is_reference_free,
             delta=args.delta,
-            reference_model_path=args.reference_model_path,
-            train_bias_only=args.train_bias_only,
+            reference_model_path=args.reference_model_path
         )
 
     if args.reference_model_path:
diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py
index 98f01ef3..e3d1b8fb 100644
--- a/llms/mlx_lm/tuner/datasets.py
+++ b/llms/mlx_lm/tuner/datasets.py
@@ -12,54 +12,25 @@ class DPODataset:
     {"system": ..., "prompt": ..., "chosen": ..., "rejected": ...}
     """
 
-    def __init__(
-        self,
-        data: List[Dict[str, str]],
-        tokenizer: PreTrainedTokenizer,
-        prompt_key: str = "prompt",
-        chosen_key: str = "chosen",
-        rejected_key: str = "rejected",
-        system_key: str = "system",
-    ):
-
+    def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer,
+                 prompt_key: str = "prompt", chosen_key: str = "chosen",
+                 rejected_key: str = "rejected", system_key: str = "system"):
         self._chosen_data = []
         self._rejected_data = []
-        self._scores = []
 
         for d in data:
-            if system_key and system_key in d:
-                chosen = tokenizer.apply_chat_template(
-                    [
-                        {"role": "system", "content": d[system_key]},
-                        {"role": "user", "content": d[prompt_key]},
-                        {"role": "assistant", "content": d[chosen_key]},
-                    ]
-                )
+            messages = (
+                [{"role": "system", "content": d[system_key]}] if system_key and system_key in d else []
+            )
+            messages.append({"role": "user", "content": d[prompt_key]})
 
-                rejected = tokenizer.apply_chat_template(
-                    [
-                        {"role": "system", "content": d[system_key]},
-                        {"role": "user", "content": d[prompt_key]},
-                        {"role": "assistant", "content": d[rejected_key]},
-                    ],
-                )
-            else:
-                chosen = tokenizer.apply_chat_template(
-                    [
-                        {"role": "user", "content": d[prompt_key]},
-                        {"role": "assistant", "content": d[chosen_key]},
-                    ]
-                )
-
-                rejected = tokenizer.apply_chat_template(
-                    [
-                        {"role": "user", "content": d[prompt_key]},
-                        {"role": "assistant", "content": d[rejected_key]},
-                    ],
-                )
-
-            self._chosen_data.append(chosen)
-            self._rejected_data.append(rejected)
+            # Apply template once for each response type
+            base_messages = messages.copy()
+            chosen_messages = base_messages + [{"role": "assistant", "content": d[chosen_key]}]
+            rejected_messages = base_messages + [{"role": "assistant", "content": d[rejected_key]}]
+
+            self._chosen_data.append(tokenizer.apply_chat_template(chosen_messages))
+            self._rejected_data.append(tokenizer.apply_chat_template(rejected_messages))
 
     def __getitem__(self, idx: int):
         return {
diff --git a/llms/mlx_lm/tuner/dpo_trainer.py b/llms/mlx_lm/tuner/dpo_trainer.py
index 142f8aea..22797c7e 100644
--- a/llms/mlx_lm/tuner/dpo_trainer.py
+++ b/llms/mlx_lm/tuner/dpo_trainer.py
@@ -46,18 +46,6 @@ class DPOTrainingArgs(TrainingArgs):
             "help": "Path to reference model weights. If None, uses the same model."
         }
     )
-    train_bias_only: bool = field(
-        default=False,
-        metadata={
-            "help": "Whether to train only bias terms in the model."
-        }
-    )
-    seed: int = field(
-        default=42,
-        metadata={
-            "help": "Random seed for reproducibility."
-        }
-    )
 
 
 def dpo_loss(
@@ -72,22 +60,13 @@ def dpo_loss(
     loss_type: str = "sigmoid",
     is_reference_free: bool = False
 ):
-    """
-    Calculate loss for inputs.
-    Args:
-        inputs: Input tokens.
-        targets: Target tokens.
-        lengths: Lengths of inputs.
-    Returns:
-        Loss value.
-    """
     def make_predictions(model, x, mask):
         inputs = x[:, :-1]
         targets = x[:, 1:]
 
         logits = model(inputs)
         logits = logits.astype(mx.float32)
-
+
         return -nn.losses.cross_entropy(logits, targets) * mask[:, :-1]
 
     num_chosen_tokens = chosen_masks.sum(-1)
@@ -121,7 +100,7 @@ def dpo_loss(
 
         logits = (policy_chosen_score - policy_rejected_score) - (reference_chosen_score - reference_rejected_score)
 
-    if loss_type == "sigmoid":
+    if loss_type == "sigmoid":  # from the original DPO paper
         losses = -nn.log_sigmoid(beta * logits)
     elif loss_type == "hinge":
         losses = nn.relu(1 - beta * logits)
@@ -144,70 +123,46 @@ def dpo_loss(
     return loss, reward, num_tokens
 
 
-def iterate_dpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
-    """
-    Modified iterate_batches for DPO training that handles chosen and rejected samples.
-    """
-    # Sort pairs by length of the chosen response
+def iterate_dpo_batches(dataset, batch_size, max_seq_length, train=False):
     idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]['chosen']))
-    if len(dataset) < batch_size:
-        raise ValueError(
-            f"Dataset must have at least batch_size={batch_size}"
-            f" examples but only has {len(dataset)}."
-        )
-
+
     step = mx.distributed.init().size()
     if batch_size % step != 0:
-        raise ValueError("The batch size must be divisible by the number of workers")
-
-    batch_idx = [
-        idx[i : i + batch_size : step]
-        for i in range(0, len(idx) - batch_size + 1, batch_size)
-    ]
-
+        raise ValueError("Batch size must be divisible by the number of workers")
+
+    batch_idx = [idx[i:i+batch_size:step] for i in range(0, len(idx)-batch_size+1, batch_size)]
+
     while True:
         indices = np.random.permutation(len(batch_idx)) if train else range(len(batch_idx))
         for i in indices:
             batch = [dataset[j] for j in batch_idx[i]]
 
-            # Get lengths for chosen and rejected sequences
+            # Get and process lengths
             chosen_lengths = [len(x['chosen']) for x in batch]
             rejected_lengths = [len(x['rejected']) for x in batch]
-            max_length = max(max(chosen_lengths), max(rejected_lengths))
+            max_length = min(max(max(chosen_lengths), max(rejected_lengths)), max_seq_length)
+
+            # Dynamic padding based on batch content
+            max_length_in_batch = max_length
 
-            if max_length > max_seq_length:
-                print(
-                    f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "
-                    f"The longest sequence {max_length} will be truncated to {max_seq_length}."
-                )
-
-            # Pad to nearest multiple of 8
-            pad_to = 8
-            max_length_in_batch = pad_to * ((max_length + pad_to - 1) // pad_to)
-            max_length_in_batch = min(max_length_in_batch, max_seq_length)
-
-            # Create arrays for chosen and rejected sequences
             chosen_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
             rejected_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
 
-            # Create attention masks
             chosen_masks = np.zeros((batch_size // step, max_length_in_batch), np.float32)
             rejected_masks = np.zeros((batch_size // step, max_length_in_batch), np.float32)
-
+
             for j in range(batch_size // step):
-                # Process chosen sequence
                 chosen_length = min(chosen_lengths[j], max_seq_length)
-                chosen_arr[j, :chosen_length] = batch[j]['chosen'][:chosen_length]
-                chosen_masks[j, :chosen_length] = 1.0
-
-                # Process rejected sequence
                 rejected_length = min(rejected_lengths[j], max_seq_length)
+
+                chosen_arr[j, :chosen_length] = batch[j]['chosen'][:chosen_length]
                 rejected_arr[j, :rejected_length] = batch[j]['rejected'][:rejected_length]
+
+                chosen_masks[j, :chosen_length] = 1.0
                 rejected_masks[j, :rejected_length] = 1.0
-
-            yield (mx.array(chosen_arr), mx.array(rejected_arr),
-                   mx.array(chosen_masks), mx.array(rejected_masks))
-
+
+            yield mx.array(chosen_arr), mx.array(rejected_arr), mx.array(chosen_masks), mx.array(rejected_masks)
+
         if not train:
             break
 
@@ -225,9 +180,6 @@ def evaluate_dpo(
     loss_fn: callable = dpo_loss,
     loss_type="sigmoid",
 ):
-    """
-    Modified evaluate function for DPO training.
-    """
     all_losses = 0
     all_rewards = mx.zeros((2,))  # [chosen_reward, rejected_reward]
     ntokens = 0
@@ -238,7 +190,6 @@ def evaluate_dpo(
         index_iterator,
         iterate_dpo_batches(
             dataset=dataset,
-            tokenizer=tokenizer,
             batch_size=batch_size,
             max_seq_length=max_seq_length,
         ),
@@ -279,9 +230,6 @@ def train_dpo(
     training_callback: TrainingCallback = None,
     loss_type="sigmoid",
 ):
-    """
-    Modified training function for DPO.
-    """
    print(f"Starting DPO training..., iters: {args.iters}")
     world = mx.distributed.init()
     world_size = world.size()
@@ -345,7 +293,6 @@ def train_dpo(
         range(1, args.iters + 1),
         iterate_dpo_batches(
             dataset=train_dataset,
-            tokenizer=tokenizer,
             batch_size=args.batch_size,
             max_seq_length=args.max_seq_length,
             train=True,
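For reference, the `sigmoid` branch kept in `dpo_loss` above is the pairwise objective from the original DPO paper, and `hinge` is its margin variant. The snippet below is only a toy illustration of that arithmetic with made-up per-sequence scores; it is not code from this patch, and all variable names are placeholders.

```python
import math

# Made-up per-sequence log-likelihood scores: the policy being trained vs. the
# frozen reference model (these numbers are illustrative only).
policy_chosen, policy_rejected = -1.20, -2.10
reference_chosen, reference_rejected = -1.50, -1.90

beta = 0.1  # plays the same role as the `beta` argument of dpo_loss

# The `logits` term in dpo_loss: how much more the policy prefers the chosen
# response over the rejected one, relative to the reference model.
margin = (policy_chosen - policy_rejected) - (reference_chosen - reference_rejected)

sigmoid_loss = -math.log(1.0 / (1.0 + math.exp(-beta * margin)))  # loss_type="sigmoid"
hinge_loss = max(0.0, 1.0 - beta * margin)                        # loss_type="hinge"

print(f"margin={margin:.2f}  sigmoid={sigmoid_loss:.4f}  hinge={hinge_loss:.4f}")
```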
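Similarly, here is a rough usage sketch of the refactored `DPODataset`, assuming a chat-template tokenizer loaded via `transformers`; the model id and sample records are placeholders, and the `chosen`/`rejected` keys returned by `__getitem__` are inferred from how `iterate_dpo_batches` indexes the dataset.

```python
# Hedged sketch, not part of the patch: exercises DPODataset from this branch.
from transformers import AutoTokenizer

from mlx_lm.tuner.datasets import DPODataset

# Placeholder model id; any tokenizer that provides a chat template should work.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

records = [
    {
        "system": "You are a helpful assistant",
        "prompt": "User prompt",
        "chosen": "Preferred response",
        "rejected": "Less preferred response",
    },
    # The "system" field is optional; records without it take the same code path.
    {"prompt": "Another prompt", "chosen": "Good answer", "rejected": "Weak answer"},
]

dataset = DPODataset(records, tokenizer)
pair = dataset[0]
# Each item holds the tokenized chosen/rejected conversations consumed by iterate_dpo_batches.
print(len(pair["chosen"]), len(pair["rejected"]))
```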