From 7989d0a874adabd770403a95e2fd43858f75c0ff Mon Sep 17 00:00:00 2001
From: Chime Ogbuji
Date: Sun, 8 Dec 2024 12:32:52 -0500
Subject: [PATCH] Move response template to LoRA configuration

---
 llms/mlx_lm/LORA.md            | 16 ++++++++++++++++
 llms/mlx_lm/lora.py            | 17 ++++++++++++++++-
 llms/mlx_lm/tokenizer_utils.py |  6 ++++++
 llms/mlx_lm/tuner/datasets.py  |  7 -------
 llms/mlx_lm/tuner/trainer.py   | 26 +++++++++++++++++++++-----
 5 files changed, 59 insertions(+), 13 deletions(-)

diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md
index d332bfaa..1490c752 100644
--- a/llms/mlx_lm/LORA.md
+++ b/llms/mlx_lm/LORA.md
@@ -81,6 +81,22 @@ There are custom functions for masking the sequence of tokens associated with th
 during the loss calculation to ensure the model is not being penalized for not recreating
 the prompt. To fine-tune with masked input sequences, use the `--mask-inputs` argument.
 
+This functionality expects a `response_template` parameter in the configuration. It can be either
+[a string that indicates the start of the model's response](https://huggingface.co/docs/transformers/en/chat_templating#what-are-generation-prompts)
+or the corresponding list of token ids. It is used to build the mask that excludes the prompt
+tokens from the loss calculation. For example (ChatML):
+
+```yaml
+response_template: "<|im_start|>assistant"
+```
+
+or (for the corresponding tokens of Gemma's response template):
+
+```yaml
+response_template: [106, 2516]
+```
+
+
 ### Evaluate
 
 To compute test set perplexity use:
diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index 87e3bb0c..192e5e30 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -12,7 +12,7 @@ import mlx.optimizers as optim
 import numpy as np
 import yaml
 
-from .tokenizer_utils import TokenizerWrapper
+from .tokenizer_utils import TokenizerWrapper, no_bos_or_eos
 from .tuner.datasets import load_dataset
 from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
 from .tuner.utils import (
@@ -63,6 +63,7 @@ CONFIG_DEFAULTS = {
     "lr_schedule": None,
     "hf_datasets": None,
     "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
+    "response_template": None,
 }
 
 
@@ -217,6 +218,17 @@ def train_model(
     adapter_file = adapter_path / "adapters.safetensors"
     save_config(vars(args), adapter_path / "adapter_config.json")
 
+    if isinstance(args.response_template, str):
+        response_generation_tokens = tokenizer.encode(
+            args.response_template, add_special_tokens=False
+        )
+    else:
+        if not all(isinstance(item, int) for item in args.response_template):
+            raise ValueError(
+                "Response template must be a list of integers if it is not a string."
+            )
+        response_generation_tokens = args.response_template
+
     # init training args
     training_args = TrainingArgs(
         batch_size=args.batch_size,
@@ -228,6 +240,9 @@ def train_model(
         adapter_file=adapter_file,
         max_seq_length=args.max_seq_length,
         grad_checkpoint=args.grad_checkpoint,
+        response_generation_tokens=no_bos_or_eos(
+            response_generation_tokens, tokenizer.bos_token_id, tokenizer.eos_token_id
+        ),
     )
 
     model.train()
diff --git a/llms/mlx_lm/tokenizer_utils.py b/llms/mlx_lm/tokenizer_utils.py
index 1b5bdd77..de9d5324 100644
--- a/llms/mlx_lm/tokenizer_utils.py
+++ b/llms/mlx_lm/tokenizer_utils.py
@@ -1,5 +1,6 @@
 import json
 from functools import partial
+from typing import List
 
 from transformers import AutoTokenizer
 
@@ -368,3 +369,8 @@ def load_tokenizer(model_path, tokenizer_config_extra={}, eos_token_ids=None):
         detokenizer_class,
         eos_token_ids=eos_token_ids,
     )
+
+
+def no_bos_or_eos(sequence: List, bos: int, eos: int) -> List:
+    removed_bos = sequence if sequence[0] != bos else sequence[1:]
+    return removed_bos[:-1] if removed_bos[-1] == eos else removed_bos
diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py
index 3ad42b2a..7d0e5026 100644
--- a/llms/mlx_lm/tuner/datasets.py
+++ b/llms/mlx_lm/tuner/datasets.py
@@ -64,7 +64,6 @@ class CompletionsDataset:
         tokenizer: PreTrainedTokenizer,
         prompt_key: str,
         completion_key: str,
-        response_template: Union[str, list[int]] = None,
     ):
         self._data = [
             tokenizer.apply_chat_template(
@@ -75,12 +74,6 @@ class CompletionsDataset:
             )
             for d in data
         ]
-        if isinstance(response_template, str):
-            self.response_token_ids = self._tokenizer.encode(
-                response_template, add_special_tokens=False
-            )
-        else:
-            self.response_token_ids = response_template
 
     def __getitem__(self, idx: int):
         return self._data[idx]
diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py
index 1b202961..6d89477c 100644
--- a/llms/mlx_lm/tuner/trainer.py
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -5,7 +5,7 @@ import shutil
 import time
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import List, Tuple
+from typing import List, Optional, Tuple
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -64,6 +64,10 @@ class TrainingArgs:
         default=False,
         metadata={"help": "Use gradient checkpointing to reduce memory use."},
     )
+    response_generation_tokens: Optional[List[int]] = field(
+        default_factory=list,
+        metadata={"help": "List of token ids that mark the beginning of the response"},
+    )
 
 
 def input_masked_loss(model, inputs, response_prefix_lengths, lengths):
@@ -114,6 +118,7 @@ def iterate_completion_batches(
     batch_size: int,
     max_seq_length: int,
     train: bool = False,
+    response_generation_tokens: Optional[List[int]] = None,
 ):
     """
     A version of iterate_batches that works with completion datasets, tracks the boundaries between input/output tokens
@@ -146,14 +151,14 @@ def iterate_completion_batches(
                 if full_sequence[-1] != tokenizer.eos_token_id:
                     full_sequence.append(tokenizer.eos_token_id)
                 batch.append(full_sequence)
-                if len(dataset.response_token_ids) > 1:
+                if len(response_generation_tokens) > 1:
                     response_marker_begin, response_marker_end = contains(
-                        dataset.response_token_ids, full_sequence
+                        response_generation_tokens, full_sequence
                     )
                     response_prefix_lengths.append(response_marker_end + 1)
                 else:
                     response_marker_begin = full_sequence.index(
-                        dataset.response_token_ids[0]
+                        response_generation_tokens[0]
                     )
                     response_prefix_lengths.append(response_marker_begin + 1)
 
@@ -190,7 +195,14 @@ def iterate_completion_batches(
             break
 
 
-def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
+def iterate_batches(
+    dataset,
+    tokenizer,
+    batch_size,
+    max_seq_length,
+    train=False,
+    response_generation_tokens=None,
+):
     # Sort by length:
     idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
     if len(dataset) < batch_size:
@@ -253,6 +265,7 @@ def evaluate(
     max_seq_length=2048,
     loss: callable = default_loss,
     iterate_batches: callable = iterate_batches,
+    response_generation_tokens: Optional[List[int]] = None,
 ):
     all_losses = mx.array(0.0)
     ntokens = mx.array(0)
@@ -266,6 +279,7 @@ def evaluate(
             tokenizer=tokenizer,
             batch_size=batch_size,
             max_seq_length=max_seq_length,
+            response_generation_tokens=response_generation_tokens,
         ),
     ):
         losses, toks = loss(model, *batch)
@@ -341,6 +355,7 @@ def train(
             batch_size=args.batch_size,
             max_seq_length=args.max_seq_length,
             train=True,
+            response_generation_tokens=args.response_generation_tokens,
        ),
    ):
        # Report validation loss if needed, the first validation loss
@@ -356,6 +371,7 @@ def train(
                 num_batches=args.val_batches,
                 max_seq_length=args.max_seq_length,
                 iterate_batches=iterate_batches,
+                response_generation_tokens=args.response_generation_tokens,
             )
             val_time = time.perf_counter() - stop
             if rank == 0:
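
For reference, the sketch below illustrates how a configured `response_template` ends up as a prompt-prefix length for loss masking. It is a hypothetical, self-contained illustration: `response_prefix_length` is an invented helper (the patch does this inside `iterate_completion_batches` via `contains`), the token ids are made up, and only `no_bos_or_eos` mirrors the function added to `tokenizer_utils.py`.

```python
# Hypothetical sketch, not part of the patch: how a response template becomes the
# number of leading tokens excluded from the loss. Token ids below are invented.

from typing import List


def no_bos_or_eos(sequence: List[int], bos: int, eos: int) -> List[int]:
    # Same behavior as the helper added to tokenizer_utils.py: strip a leading BOS
    # and a trailing EOS from the encoded response template.
    trimmed = sequence if sequence[0] != bos else sequence[1:]
    return trimmed[:-1] if trimmed[-1] == eos else trimmed


def response_prefix_length(sequence: List[int], marker: List[int]) -> int:
    # Locate the response marker inside the full prompt+completion token sequence
    # and return how many leading tokens should be masked out of the loss.
    if len(marker) > 1:
        for start in range(len(sequence) - len(marker) + 1):
            if sequence[start : start + len(marker)] == marker:
                return start + len(marker)
        raise ValueError("response template not found in sequence")
    return sequence.index(marker[0]) + 1


# Invented ids: BOS=1, EOS=2, response template tokens [106, 2516] (Gemma-style).
marker = no_bos_or_eos([1, 106, 2516], bos=1, eos=2)   # -> [106, 2516]
tokens = [1, 5, 9, 106, 2516, 42, 77, 2]               # prompt + response
print(response_prefix_length(tokens, marker))          # -> 5: first 5 tokens are masked
```

Stripping BOS/EOS from the encoded template matters because the template is searched for inside full sequences that already carry their own special tokens.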