From 6e9542a934be18fa9a22f9b937fcac7f29f011f0 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Sun, 9 Feb 2025 17:31:23 -0800
Subject: [PATCH] put offset in prompt, simplify

---
 llms/mlx_lm/lora.py           |  36 +---------
 llms/mlx_lm/tuner/datasets.py |  85 ++++++++++++----------
 llms/mlx_lm/tuner/trainer.py  | 128 ++++------------------------------
 3 files changed, 64 insertions(+), 185 deletions(-)

diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index 5ac58aa1..abc5dfa9 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -12,7 +12,7 @@ import mlx.optimizers as optim
 import numpy as np
 import yaml
 
-from .tokenizer_utils import TokenizerWrapper, no_bos_or_eos
+from .tokenizer_utils import TokenizerWrapper
 from .tuner.datasets import load_dataset
 from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
 from .tuner.utils import (
@@ -62,7 +62,6 @@ CONFIG_DEFAULTS = {
     "grad_checkpoint": False,
     "lr_schedule": None,
     "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
-    "response_template": None,
 }
 
 
@@ -97,10 +96,9 @@ def build_parser():
     )
 
     parser.add_argument(
-        "--mask-inputs",
-        dest="mask_inputs",
+        "--mask-prompt",
         action="store_true",
-        help="Whether to mask the inputs when training. Default is False.",
+        help="Mask the prompt in the loss when training",
         default=False,
     )
@@ -182,13 +180,6 @@ def train_model(
     valid_set,
     training_callback: TrainingCallback = None,
 ):
-    from .tuner.trainer import (
-        default_loss,
-        input_masked_loss,
-        iterate_batches,
-        iterate_completion_batches,
-    )
-
     model.freeze()
     if args.fine_tune_type == "full":
         for l in model.layers[-max(args.num_layers, 0) :]:
@@ -217,17 +208,6 @@ def train_model(
     adapter_file = adapter_path / "adapters.safetensors"
     save_config(vars(args), adapter_path / "adapter_config.json")
 
-    if isinstance(args.response_template, str):
-        response_generation_tokens = tokenizer.encode(
-            args.response_template, add_special_tokens=False
-        )
-    else:
-        if not all([item.isinstance(int) for item in args.response_template]):
-            raise ValueError(
-                "Response template must be a list of integers if it is not a string."
-            )
-        response_generation_tokens = args.response_template
-
     # init training args
     training_args = TrainingArgs(
         batch_size=args.batch_size,
@@ -239,9 +219,6 @@
         adapter_file=adapter_file,
         max_seq_length=args.max_seq_length,
         grad_checkpoint=args.grad_checkpoint,
-        response_generation_tokens=no_bos_or_eos(
-            response_generation_tokens, tokenizer.bos_token_id, tokenizer.eos_token_id
-        ),
     )
 
     model.train()
@@ -251,9 +228,6 @@
         )
     )
 
-    if args.mask_inputs:
-        print("Masking inputs..")
-
     # Train model
     train(
         model=model,
@@ -263,10 +237,6 @@
         train_dataset=train_set,
         val_dataset=valid_set,
         training_callback=training_callback,
-        iterate_batches=(
-            iterate_completion_batches if args.mask_inputs else iterate_batches
-        ),
-        loss=input_masked_loss if args.mask_inputs else default_loss,
     )

diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py
index 1f990fb7..174d05ca 100644
--- a/llms/mlx_lm/tuner/datasets.py
+++ b/llms/mlx_lm/tuner/datasets.py
@@ -40,14 +40,19 @@ class ChatDataset:
         data: List[Dict[str, str]],
         tokenizer: PreTrainedTokenizer,
         chat_key: str = "messages",
+        mask_prompt: bool = False,
     ):
-        self._data = [
-            tokenizer.apply_chat_template(
-                d[chat_key],
-                tools=d.get("tools", None),
-            )
-            for d in data
-        ]
+        self._data = []
+        for d in data:
+            messages = d[chat_key]
+            tools = d.get("tools", None)
+            tokens = tokenizer.apply_chat_template(messages, tools=tools)
+            if mask_prompt:
+                messages = messages[:-1]
+                offset = len(tokenizer.apply_chat_template(messages, tools=tools))
+                self._data.append((tokens, offset))
+            else:
+                self._data.append(tokens)
 
     def __getitem__(self, idx: int):
         return self._data[idx]
@@ -69,16 +74,25 @@ class CompletionsDataset:
         tokenizer: PreTrainedTokenizer,
         prompt_key: str,
         completion_key: str,
+        mask_prompt: bool,
     ):
-        self._data = [
-            tokenizer.apply_chat_template(
+        self._data = []
+        for d in data:
+            tokens = tokenizer.apply_chat_template(
                 [
                     {"role": "user", "content": d[prompt_key]},
                     {"role": "assistant", "content": d[completion_key]},
                 ],
             )
-            for d in data
-        ]
+            if mask_prompt:
+                offset = len(
+                    tokenizer.apply_chat_template(
+                        [{"role": "user", "content": d[prompt_key]}]
+                    )
+                )
+                self._data.append((tokens, offset))
+            else:
+                self._data.append(tokens)
 
     def __getitem__(self, idx: int):
         return self._data[idx]
@@ -101,17 +115,21 @@ class ConcatenatedDataset:
 def create_dataset(
     data,
     tokenizer: PreTrainedTokenizer,
-    prompt_feature: Optional[str] = None,
-    completion_feature: Optional[str] = None,
+    config: Dict,
 ):
-    prompt_feature = prompt_feature or "prompt"
-    completion_feature = completion_feature or "completion"
+    mask_prompt = getattr(config, "mask_prompt", False)
+    prompt_feature = getattr(config, "prompt_feature", "prompt")
+    completion_feature = getattr(config, "completion_feature", "completion")
     sample = data[0]
 
     if "messages" in sample:
-        return ChatDataset(data, tokenizer)
+        return ChatDataset(data, tokenizer, mask_prompt=mask_prompt)
     elif prompt_feature in sample and completion_feature in sample:
-        return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature)
+        return CompletionsDataset(
+            data, tokenizer, prompt_feature, completion_feature, mask_prompt
+        )
    elif "text" in sample:
+        if mask_prompt:
+            raise ValueError("Prompt masking not supported for text dataset.")
         return Dataset(data, tokenizer)
     else:
         raise ValueError(
@@ -123,15 +141,14 @@ def load_local_dataset(
     data_path: Path,
     tokenizer: PreTrainedTokenizer,
-    prompt_feature: Optional[str] = None,
-    completion_feature: Optional[str] = None,
+    config: Dict,
 ):
     def load_subset(path):
         if not path.exists():
             return []
         with open(path, "r") as fid:
             data = [json.loads(l) for l in fid]
-        return create_dataset(data, tokenizer, prompt_feature, completion_feature)
+        return create_dataset(data, tokenizer, config)
 
     names = ("train", "valid", "test")
     train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names]
@@ -141,8 +158,7 @@ def load_hf_dataset(
     data_id: str,
     tokenizer: PreTrainedTokenizer,
-    prompt_feature: Optional[str] = None,
-    completion_feature: Optional[str] = None,
+    config: Dict,
 ):
     from datasets import exceptions, load_dataset
@@ -153,9 +169,7 @@
         train, valid, test = [
             (
-                create_dataset(
-                    dataset[n], tokenizer, prompt_feature, completion_feature
-                )
+                create_dataset(dataset[n], tokenizer, config)
                 if n in dataset.keys()
                 else []
             )
@@ -186,10 +200,16 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
             **config,
         )
         if prompt_feature and completion_feature:
-            return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
+            return CompletionsDataset(
+                ds, tokenizer, prompt_feature, completion_feature, mask_prompt
+            )
         elif chat_feature:
-            return ChatDataset(ds, tokenizer, chat_key=chat_feature)
+            return ChatDataset(
+                ds, tokenizer, chat_key=chat_feature, mask_prompt=mask_prompt
+            )
         elif text_feature:
+            if mask_prompt:
+                raise ValueError("Prompt masking not supported for text dataset.")
             return Dataset(ds, tokenizer, text_key=text_feature)
         else:
             raise ValueError(
@@ -262,18 +282,11 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer):
         train, valid, test = load_custom_hf_dataset(args, tokenizer)
     else:
         data_path = Path(args.data)
-
-        prompt_feature = getattr(args, "prompt_feature", None)
-        completion_feature = getattr(args, "completion_feature", None)
         if data_path.exists():
-            train, valid, test = load_local_dataset(
-                data_path, tokenizer, prompt_feature, completion_feature
-            )
+            train, valid, test = load_local_dataset(data_path, tokenizer, args)
         else:
             print(f"Loading Hugging Face dataset {args.data}.")
-            train, valid, test = load_hf_dataset(
-                args.data, tokenizer, prompt_feature, completion_feature
-            )
+            train, valid, test = load_hf_dataset(args.data, tokenizer, args)
 
     if args.train and len(train) == 0:
         raise ValueError(
diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py
index 6d89477c..f2f0cb5d 100644
--- a/llms/mlx_lm/tuner/trainer.py
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -64,39 +64,20 @@ class TrainingArgs:
         default=False,
         metadata={"help": "Use gradient checkpointing to reduce memory use."},
     )
-    response_generation_tokens: Optional[List[int]] = field(
-        default_factory=list,
-        metadata={"help": "List of token ids that mark the beginning of the response"},
-    )
 
 
-def input_masked_loss(model, inputs, response_prefix_lengths, lengths):
-    shifted_inputs = inputs[:, :-1]
-    shifted_labels = inputs[:, 1:]
-    logits = model(shifted_inputs)
-    logits = logits.astype(mx.float32)
+def default_loss(model, batch, lengths):
+    inputs = batch[:, :-1]
+    targets = batch[:, 1:]
 
-    mask_width = shifted_inputs.shape[1]
-    token_indices = mx.arange(mask_width)[None, :]
-    mask = mx.logical_and(
-        token_indices >= response_prefix_lengths[:, None],
-        token_indices < lengths[:, None],
-    )
-
-    ce = nn.losses.cross_entropy(logits, shifted_labels) * mask
-    ntoks = mask.sum()
-    ce = ce.sum() / ntoks
-    return ce, ntoks
-
-
-def default_loss(model, inputs, targets, lengths):
     logits = model(inputs)
     logits = logits.astype(mx.float32)
 
-    length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
+    steps = mx.arange(1, targets.shape[1] + 1)
+    mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:])
 
-    ce = nn.losses.cross_entropy(logits, targets) * length_mask
-    ntoks = length_mask.sum()
+    ce = nn.losses.cross_entropy(logits, targets) * mask
+    ntoks = mask.sum()
     ce = ce.sum() / ntoks
     return ce, ntoks
 
@@ -112,96 +93,12 @@ def contains(small_list: List, big_list: List) -> Tuple[int, int]:
         return ind, ind + small_list_length - 1
 
 
-def iterate_completion_batches(
-    dataset: CompletionsDataset,
-    tokenizer: PreTrainedTokenizer,
-    batch_size: int,
-    max_seq_length: int,
-    train: bool = False,
-    response_generation_tokens: Optional[List[int]] = None,
-):
-    """
-    A version of iterate_batches that works with completion datasets, tracks the boundaries between input/output tokens
-    and returns the lengths of input tokens as well as that of the full sequences.
-    """
-    idx = sorted(range(len(dataset)), key=lambda i: len(dataset[i]))
-    if len(dataset) < batch_size:
-        raise ValueError(
-            f"Dataset must have at least batch_size={batch_size}"
-            f" examples but only has {len(dataset)}."
-        )
-
-    # If running in distributed mode (N machines) then each one should skip N-1
-    # samples
-    step = mx.distributed.init().size()
-    if batch_size % step != 0:
-        raise ValueError("The batch size must be divisible by the number of workers")
-
-    # Make the batches:
-    batch_idx = [
-        idx[i : i + batch_size : step]
-        for i in range(0, len(idx) - batch_size + 1, batch_size)
-    ]
-
-    while True:
-        indices = np.random.permutation(len(batch_idx))
-        for i in indices:
-            response_prefix_lengths = []
-            batch = []
-            for j in batch_idx[i]:
-                full_sequence = dataset.get_item(j, tokenize=True)
-                if full_sequence[-1] != tokenizer.eos_token_id:
-                    full_sequence.append(tokenizer.eos_token_id)
-                batch.append(full_sequence)
-
-                if len(response_generation_tokens) > 1:
-                    response_marker_begin, response_marker_end = contains(
-                        response_generation_tokens, full_sequence
-                    )
-                    response_prefix_lengths.append(response_marker_end + 1)
-                else:
-                    response_marker_begin = full_sequence.index(
-                        response_generation_tokens[0]
-                    )
-                    response_prefix_lengths.append(response_marker_begin + 1)
-
-            lengths = [len(x) for x in batch]
-
-            if max(lengths) > max_seq_length:
-                print(
-                    f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "
-                    f"The longest sentence {max(lengths)} will be truncated to {max_seq_length}. "
-                    "Consider pre-splitting your data to save memory."
-                )
-
-            # Pad to the nearest multiple of 8 or the maximum length
-            pad_to = 8
-            max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
-            max_length_in_batch = min(max_length_in_batch, max_seq_length)
-
-            batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
-            for j in range(batch_size // step):
-                response_prefix_length = response_prefix_lengths[j]
-                truncated_length = min(lengths[j], max_seq_length)
-                batch_arr[j, response_prefix_length:truncated_length] = batch[j][
-                    response_prefix_length:truncated_length
-                ]
-                lengths[j] = (
-                    truncated_length  # Update lengths to match truncated lengths
-                )
-
-            yield mx.array(batch_arr), mx.array(response_prefix_lengths), mx.array(
-                lengths
-            )
-
-        if not train:
-            break
-
-
 def iterate_batches(
     dataset,
     tokenizer,
     batch_size,
     max_seq_length,
     train=False,
-    response_generation_tokens=None,
 ):
     # Sort by length:
     idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
@@ -227,6 +124,10 @@
         indices = np.random.permutation(len(batch_idx))
         for i in indices:
             batch = [dataset[j] for j in batch_idx[i]]
+            if len(batch[0]) == 2:
+                batch, offsets = zip(*batch)
+            else:
+                offsets = [0] * len(batch)
             lengths = [len(x) for x in batch]
             if max(lengths) > max_seq_length:
                 print(
@@ -249,8 +150,7 @@
                 truncated_length  # Update lengths to match truncated lengths
             )
         batch = mx.array(batch_arr)
-
-        yield batch[:, :-1], batch[:, 1:], mx.array(lengths)
+        yield batch, mx.array(list(zip(offsets, lengths)))
 
         if not train:
             break
@@ -265,7 +165,6 @@ def evaluate(
     max_seq_length=2048,
     loss: callable = default_loss,
     iterate_batches: callable = iterate_batches,
-    response_generation_tokens: Optional[List[int]] = None,
 ):
     all_losses = mx.array(0.0)
     ntokens = mx.array(0)
@@ -279,7 +178,6 @@
             tokenizer=tokenizer,
             batch_size=batch_size,
             max_seq_length=max_seq_length,
-            response_generation_tokens=response_generation_tokens,
         ),
     ):
         losses, toks = loss(model, *batch)
@@ -355,7 +253,6 @@
             batch_size=args.batch_size,
             max_seq_length=args.max_seq_length,
             train=True,
-            response_generation_tokens=args.response_generation_tokens,
        ),
     ):
         # Report validation loss if needed, the first validation loss
@@ -371,7 +268,6 @@
                 num_batches=args.val_batches,
                 max_seq_length=args.max_seq_length,
                 iterate_batches=iterate_batches,
-                response_generation_tokens=args.response_generation_tokens,
             )
             val_time = time.perf_counter() - stop
             if rank == 0:
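
A note for readers on the data flow this patch sets up: with --mask-prompt, ChatDataset and CompletionsDataset store each example as a (tokens, offset) pair, where offset is the tokenized length of the prompt-only chat (the messages with the final assistant turn dropped), and iterate_batches converts a batch of such pairs into one (offset, length) row per example. A minimal sketch with made-up token ids (no real tokenizer; all values below are illustrative, not from the patch):

    # Toy stand-ins for tokenizer.apply_chat_template output (made-up ids).
    prompt_tokens = [1, 15, 27, 9]                 # prompt-only render
    full_tokens = prompt_tokens + [42, 57, 33, 2]  # prompt + assistant reply

    example = (full_tokens, len(prompt_tokens))    # what the dataset stores

    # The same tuple-vs-list split the patch adds inside iterate_batches:
    batch = [example]
    if len(batch[0]) == 2:
        batch, offsets = zip(*batch)
    else:
        offsets = [0] * len(batch)
    lengths = [len(x) for x in batch]
    print(list(zip(offsets, lengths)))             # [(4, 8)]

Unmasked datasets still yield bare token lists, which take the offsets = [0] * len(batch) branch, so a single yield signature, the padded batch plus an (offset, length) array, serves both cases.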
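On the loss side, the patched default_loss turns each (offset, length) row into a two-sided mask: a 1-indexed target step s contributes to the cross-entropy only when offset <= s <= length, which confines the loss to each row's own completion window. Below is a standalone sketch of just that arithmetic, using random logits and assumed toy shapes (illustrative values, not from the patch):

    import mlx.core as mx
    import mlx.nn as nn

    width, vocab = 7, 16  # targets.shape[1] for a batch padded to 8 tokens
    targets = mx.zeros((2, width), dtype=mx.int32)  # dummy target ids
    logits = mx.random.normal((2, width, vocab)).astype(mx.float32)

    # One (offset, length) row per example, as yielded by iterate_batches:
    # row 0 masks a 3-token prompt of a 5-token sequence, row 1 a 2-token
    # prompt of a 7-token sequence.
    lengths = mx.array([[3, 5], [2, 7]])

    # Same expressions as the patched default_loss.
    steps = mx.arange(1, width + 1)
    mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:])

    ce = nn.losses.cross_entropy(logits, targets) * mask
    print(ce.sum() / mask.sum())  # mean over the 9 unmasked positions

Dividing by mask.sum() keeps the result an average over only the tokens actually scored, so masked and unmasked runs report losses on the same per-token scale.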