From 68403f5577b1fac4c34d2805aff3b448351a72be Mon Sep 17 00:00:00 2001 From: paNikitin <115797306+paNikitin@users.noreply.github.com> Date: Sun, 23 Feb 2025 12:31:44 +0300 Subject: [PATCH] added cot loss masking training --- llms/mlx_lm/examples/lora_config.yaml | 6 + llms/mlx_lm/lora.py | 46 +++++--- llms/mlx_lm/tuner/new_tokens.py | 162 ++++++++++++++++++++++++++ llms/mlx_lm/tuner/trainer.py | 118 +++++++++++++------ 4 files changed, 279 insertions(+), 53 deletions(-) create mode 100644 llms/mlx_lm/tuner/new_tokens.py diff --git a/llms/mlx_lm/examples/lora_config.yaml b/llms/mlx_lm/examples/lora_config.yaml index 530272c7..ca799f8d 100644 --- a/llms/mlx_lm/examples/lora_config.yaml +++ b/llms/mlx_lm/examples/lora_config.yaml @@ -64,6 +64,12 @@ lora_parameters: scale: 20.0 dropout: 0.0 +# cot loss masking training +# cot: +# use_cot: true +# special: true +# additional_tokens: ["[REASONING]", "[DATA]"] + # Schedule can only be specified in a config file, uncomment to use. #lr_schedule: # name: cosine_decay diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py index def3b6dd..45119025 100644 --- a/llms/mlx_lm/lora.py +++ b/llms/mlx_lm/lora.py @@ -62,6 +62,7 @@ CONFIG_DEFAULTS = { "grad_checkpoint": False, "lr_schedule": None, "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}, + "cot": False, } @@ -78,7 +79,6 @@ def build_parser(): "--train", action="store_true", help="Do training", - default=None, ) parser.add_argument( "--data", @@ -94,14 +94,6 @@ def build_parser(): choices=["lora", "dora", "full"], help="Type of fine-tuning to perform: lora, dora, or full.", ) - - parser.add_argument( - "--mask-prompt", - action="store_true", - help="Mask the prompt in the loss when training", - default=False, - ) - parser.add_argument( "--num-layers", type=int, @@ -144,7 +136,6 @@ def build_parser(): "--test", action="store_true", help="Evaluate on the test set after training", - default=None, ) parser.add_argument( "--test-batches", @@ -166,9 +157,13 @@ def build_parser(): "--grad-checkpoint", action="store_true", help="Use gradient checkpointing to reduce memory use.", - default=None, ) parser.add_argument("--seed", type=int, help="The PRNG seed") + parser.add_argument( + "--cot", + type=bool, + help="Use CoT loss masking", + ) return parser @@ -181,14 +176,8 @@ def train_model( training_callback: TrainingCallback = None, ): model.freeze() - if args.num_layers > len(model.layers): - raise ValueError( - f"Requested to train {args.num_layers} layers " - f"but the model only has {len(model.layers)} layers." - ) - if args.fine_tune_type == "full": - for l in model.layers[-max(args.num_layers, 0) :]: + for l in model.layers[-min(args.num_layers, 0) :]: l.unfreeze() elif args.fine_tune_type in ["lora", "dora"]: # Convert linear layers to lora/dora layers and unfreeze in the process @@ -225,10 +214,13 @@ def train_model( adapter_file=adapter_file, max_seq_length=args.max_seq_length, grad_checkpoint=args.grad_checkpoint, + cot=(cot := args.cot), ) model.train() - opt = optim.Adam( + # todo optimizer from args + + opt = optim.AdamW( learning_rate=( build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate ) @@ -269,6 +261,21 @@ def run(args, training_callback: TrainingCallback = None): print("Loading pretrained model") model, tokenizer = load(args.model) + if cot := args.cot: + print("Using CoT loss masking") + if tokens := cot.get("additional_tokens"): + from .tuner.new_tokens import implement_new_tokens + + special = False + if (special_arg := cot.get("special")) and isinstance(special_arg, bool): + print("Updating model and tokenizer with new special tokens") + special = special_arg + else: + print("Updating model and tokenizer with new tokens") + model, tokenizer = implement_new_tokens( + model=model, tokenizer=tokenizer, tokens=tokens, special=special + ) + print("Loading datasets") train_set, valid_set, test_set = load_dataset(args, tokenizer) @@ -293,6 +300,7 @@ def main(): parser = build_parser() args = parser.parse_args() config = args.config + args = vars(args) if config: print("Loading configuration file", config) diff --git a/llms/mlx_lm/tuner/new_tokens.py b/llms/mlx_lm/tuner/new_tokens.py new file mode 100644 index 00000000..489a6705 --- /dev/null +++ b/llms/mlx_lm/tuner/new_tokens.py @@ -0,0 +1,162 @@ +import mlx.nn as nn +import mlx.core as mx +from mlx_lm.tokenizer_utils import TokenizerWrapper + + +def resize_embeddings(model: nn.Module, tokenizer: TokenizerWrapper) -> nn.Module: + """ + Resizes model embeddings to accommodate new tokens + """ + old_embedding = model.model.embed_tokens + + old_vocab_size = old_embedding.num_embeddings + new_vocab_size = len(tokenizer._tokenizer) + + if old_vocab_size != new_vocab_size: + if new_vocab_size < old_vocab_size: + print( + "Warning: New vocab size is smaller than original. Proceeding with trim." + ) + + # check if QuantizedEmbedding has required attributes for dequantization + try: + dequantized_weights = mx.dequantize( + old_embedding.weight, + scales=old_embedding.scales, + biases=old_embedding.biases, + group_size=old_embedding.group_size, + bits=old_embedding.bits, + ) + except AttributeError as e: + print(f"Error: Cannot dequantize embed_tokens. Missing attributes: {e}") + print("Falling back to random weights for embed_tokens.") + dequantized_weights = mx.random.normal( + (old_vocab_size, old_embedding.dims), loc=0.0, scale=0.02 + ) + + # resize embed_tokens + new_embedding = nn.Embedding(new_vocab_size, old_embedding.dims) + new_weights = mx.zeros((new_vocab_size, old_embedding.dims)) + min_vocab_size = min(old_vocab_size, new_vocab_size) + new_weights[:min_vocab_size] = dequantized_weights[:min_vocab_size] + if new_vocab_size > old_vocab_size: + new_weights[old_vocab_size:] = mx.random.normal( + (new_vocab_size - old_vocab_size, old_embedding.dims), + loc=0.0, + scale=0.02, + ) + new_embedding.weight = new_weights + model.model.embed_tokens = new_embedding + + # attention layers handling + if hasattr(model, "args") and getattr(model.args, "tie_word_embeddings", False): + model.model.embed_tokens.weight = new_weights + elif hasattr(model, "lm_head"): + old_lm_head = model.lm_head + if isinstance(old_lm_head, nn.QuantizedLinear): + # resize nn.QuantizedLinear + output_dims, compressed_input_dims = old_lm_head.weight.shape + bits = old_lm_head.bits + input_dims = compressed_input_dims * (32 // bits) + + # dequantize lm_head weights + try: + dequantized_lm_weights = mx.dequantize( + old_lm_head.weight, + scales=old_lm_head.scales, + biases=old_lm_head.biases, + group_size=old_lm_head.group_size, + bits=old_lm_head.bits, + ) + except AttributeError as e: + print(f"Error: Cannot dequantize lm_head. Missing attributes: {e}") + print("Falling back to random weights for lm_head.") + dequantized_lm_weights = mx.random.normal( + (output_dims, input_dims), loc=0.0, scale=0.02 + ) + + new_lm_head = nn.QuantizedLinear( + input_dims=input_dims, + output_dims=new_vocab_size, + bias="bias" in old_lm_head, + group_size=old_lm_head.group_size, + bits=old_lm_head.bits, + ) + new_weights_lm = mx.zeros((new_vocab_size, input_dims)) + new_weights_lm[:min_vocab_size] = dequantized_lm_weights[ + :min_vocab_size + ] + if new_vocab_size > output_dims: + new_weights_lm[output_dims:] = mx.random.normal( + (new_vocab_size - output_dims, input_dims), loc=0.0, scale=0.02 + ) + new_lm_head.weight, new_lm_head.scales, new_lm_head.biases = ( + mx.quantize( + new_weights_lm, new_lm_head.group_size, new_lm_head.bits + ) + ) + if "bias" in old_lm_head: + new_lm_head.bias = mx.zeros((new_vocab_size,)) + new_lm_head.bias[:min_vocab_size] = old_lm_head.bias[ + :min_vocab_size + ] + else: + # resize nn.Linear + new_lm_head = nn.Linear( + old_lm_head.input_dims, new_vocab_size, bias="bias" in old_lm_head + ) + new_weights_lm = mx.zeros((new_vocab_size, old_lm_head.input_dims)) + min_vocab_size = min(old_lm_head.weight.shape[0], new_vocab_size) + new_weights_lm[:min_vocab_size] = old_lm_head.weight[:min_vocab_size] + if new_vocab_size > old_lm_head.weight.shape[0]: + new_weights_lm[old_lm_head.weight.shape[0] :] = mx.random.normal( + ( + new_vocab_size - old_lm_head.weight.shape[0], + old_lm_head.input_dims, + ), + loc=0.0, + scale=0.02, + ) + new_lm_head.weight = new_weights_lm + # todo typechecking + if "bias" in old_lm_head: + new_lm_head.bias = mx.zeros((new_vocab_size,)) + new_lm_head.bias[:min_vocab_size] = old_lm_head.bias[ + :min_vocab_size + ] + + model.lm_head = new_lm_head + else: + print("Vocab already sized right.") + return model + + +def update_tokenizer( + tokenizer: TokenizerWrapper, tokens: list[str], special: bool +) -> TokenizerWrapper: + """ + Appends new tokens to the end of the tokenizer vocab + """ + if special: + # todo TokenizerWrapper access method + tokenizer._tokenizer.add_special_tokens({"additional_special_tokens": tokens}) + print(f"Tokenizer updated with special tokens: {tokens}") + print(f"Tokenizer vocab size after append: {len(tokenizer._tokenizer)}") + else: + # todo add regular tokens + pass + return tokenizer + + +def implement_new_tokens( + model: nn.Module, + tokenizer: TokenizerWrapper, + tokens: list[str], + special: bool = False, +) -> tuple[nn.Module, TokenizerWrapper]: + """ + Update model`s tokenizer and embeddings with new tokens accordingly + """ + tokenizer = update_tokenizer(tokenizer=tokenizer, tokens=tokens, special=special) + model = resize_embeddings(model=model, tokenizer=tokenizer) + return model, tokenizer \ No newline at end of file diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py index 64e26af8..4f5db2ea 100644 --- a/llms/mlx_lm/tuner/trainer.py +++ b/llms/mlx_lm/tuner/trainer.py @@ -1,20 +1,20 @@ # Copyright © 2024 Apple Inc. +from functools import partial import glob import shutil import time from dataclasses import dataclass, field from pathlib import Path -from typing import List, Optional, Tuple +from typing import Union import mlx.core as mx import mlx.nn as nn import numpy as np from mlx.nn.utils import average_gradients from mlx.utils import tree_flatten -from transformers import PreTrainedTokenizer -from .datasets import CompletionsDataset +from mlx_lm.tokenizer_utils import TokenizerWrapper def grad_checkpoint(layer): @@ -64,32 +64,80 @@ class TrainingArgs: default=False, metadata={"help": "Use gradient checkpointing to reduce memory use."}, ) + cot: bool = field( + default=False, + metadata={"help": "Use CoT loss masking with positioning penalty"}, + ) -def default_loss(model, batch, lengths): - inputs = batch[:, :-1] - targets = batch[:, 1:] - +def default_loss(model, inputs, targets, lengths): logits = model(inputs) logits = logits.astype(mx.float32) - steps = mx.arange(1, targets.shape[1] + 1) - mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:]) + length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None] - ce = nn.losses.cross_entropy(logits, targets) * mask - ntoks = mask.sum() + ce = nn.losses.cross_entropy(logits, targets) * length_mask + ntoks = length_mask.sum() ce = ce.sum() / ntoks return ce, ntoks -def iterate_batches( - dataset, - tokenizer, - batch_size, - max_seq_length, - train=False, -): +@dataclass +class CotTrainingArgs: + cot: bool = False + reasoning_token: str = "[REASONING]" + data_token: str = "[DATA]" + + +def cot_loss( + model: nn.Module, + inputs: mx.array, + targets: mx.array, + lengths: int, + tokenizer: TokenizerWrapper, + penalty: mx.float32 = 10.0, +) -> tuple[mx.array, mx.array]: + logits = model(inputs).astype(mx.float32) + + reasoning_token_id = tokenizer.encode(CotTrainingArgs.reasoning_token)[0] + data_token_id = tokenizer.encode(CotTrainingArgs.data_token)[0] + + reasoning_positions = mx.argmax(targets == reasoning_token_id, axis=1) + data_positions = mx.argmax(targets == data_token_id, axis=1) + + seq_indices = mx.arange(targets.shape[1])[None, :] + + # base CoT mask: starts at [DATA] + cot_mask = (seq_indices >= data_positions[:, None]).astype(mx.float32) + + # length mask: limits to non-padded regions + length_mask = (seq_indices < lengths[:, None]).astype(mx.float32) + + # combine masks: only include tokens after [DATA] AND within sequence length + loss_mask = cot_mask * length_mask + + # validate sequence structure + valid_seq = ( + (reasoning_positions < data_positions) + & mx.any(targets == reasoning_token_id, axis=1) + & mx.any(targets == data_token_id, axis=1) + ) + + # compute base cross-entropy loss + ce = nn.losses.cross_entropy(logits, targets) + + # masking loss before [DATA]; applying penalty for invalid seq + valid_loss = (ce * loss_mask).sum(axis=1) / (mx.sum(loss_mask, axis=1) + 1e-8) + final_loss = mx.where(valid_seq, valid_loss, penalty) # 10.0 as invalid penalty + loss = mx.mean(final_loss) + + valid_tokens = mx.sum(loss_mask) + 1e-8 + + return loss, valid_tokens + + +def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False): # Sort by length: idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx])) if len(dataset) < batch_size: @@ -114,10 +162,6 @@ def iterate_batches( indices = np.random.permutation(len(batch_idx)) for i in indices: batch = [dataset[j] for j in batch_idx[i]] - if len(batch[0]) == 2: - batch, offsets = zip(*batch) - else: - offsets = [0] * len(batch) lengths = [len(x) for x in batch] if max(lengths) > max_seq_length: print( @@ -140,7 +184,8 @@ def iterate_batches( truncated_length # Update lengths to match truncated lengths ) batch = mx.array(batch_arr) - yield batch, mx.array(list(zip(offsets, lengths))) + + yield batch[:, :-1], batch[:, 1:], mx.array(lengths) if not train: break @@ -156,8 +201,8 @@ def evaluate( loss: callable = default_loss, iterate_batches: callable = iterate_batches, ): - all_losses = mx.array(0.0) - ntokens = mx.array(0) + all_losses = 0 + ntokens = 0 index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1) @@ -213,6 +258,11 @@ def train( if args.grad_checkpoint: grad_checkpoint(model.layers[0]) + if args.cot: + loss = partial(cot_loss, tokenizer=tokenizer, penalty=10.0) + else: + loss = default_loss + state = [model.state, optimizer.state] def step(batch): @@ -233,8 +283,8 @@ def train( n_tokens = 0 steps = 0 trained_tokens = 0 - train_time = 0 # Main training loop + start = time.perf_counter() for it, batch in zip( range(1, args.iters + 1), iterate_batches( @@ -245,11 +295,10 @@ def train( train=True, ), ): - tic = time.perf_counter() # Report validation loss if needed, the first validation loss # is always measured before any training. if it == 1 or it % args.steps_per_eval == 0 or it == args.iters: - tic = time.perf_counter() + stop = time.perf_counter() val_loss = evaluate( model=model, dataset=val_dataset, @@ -260,7 +309,7 @@ def train( max_seq_length=args.max_seq_length, iterate_batches=iterate_batches, ) - val_time = time.perf_counter() - tic + val_time = time.perf_counter() - stop if rank == 0: print( f"Iter {it}: " @@ -277,23 +326,24 @@ def train( } training_callback.on_val_loss_report(val_info) - tic = time.perf_counter() + start = time.perf_counter() lvalue, toks = step(batch) losses += lvalue n_tokens += toks steps += 1 mx.eval(state, losses, n_tokens) - train_time += time.perf_counter() - tic # Report training loss if needed if it % args.steps_per_report == 0 or it == args.iters: + stop = time.perf_counter() + train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item() train_loss /= steps * mx.distributed.init().size() n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item() learning_rate = optimizer.learning_rate.item() - it_sec = args.steps_per_report / train_time - tokens_sec = float(n_tokens) / train_time + it_sec = args.steps_per_report / (stop - start) + tokens_sec = float(n_tokens) / (stop - start) trained_tokens += n_tokens peak_mem = mx.metal.get_peak_memory() / 1e9 if rank == 0: @@ -322,7 +372,7 @@ def train( losses = 0 n_tokens = 0 steps = 0 - train_time = 0 + start = time.perf_counter() # Save adapter weights if it % args.steps_per_save == 0: