Move response template to LoRA configuration

Chime Ogbuji 2024-12-08 12:32:52 -05:00 committed by Awni Hannun
parent 95e1f22812
commit 7989d0a874
5 changed files with 59 additions and 13 deletions

View File

@ -81,6 +81,22 @@ There are custom functions for masking the sequence of tokens associated with th
during the loss calculation to ensure the model is not being penalized for not recreating the prompt. To fine-tune
with masked input sequences, use the `--mask-inputs` argument.
This functionality expects a `response_template` parameter in the configuration that is either a
[string marking the start of the model's response](https://huggingface.co/docs/transformers/en/chat_templating#what-are-generation-prompts)
or its corresponding token ids. It is used to build the mask that excludes the prompt portion of the
sequence (everything up to the start of the model's response) from the loss calculation. For example (ChatML):
```yaml
response_template: "<|im_start|>assistant"
```
or, for the corresponding token ids of Gemma's response template:
```yaml
response_template: [106, 2516]
```
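If you prefer to supply the token ids directly, you can look them up with the model's tokenizer. A minimal
sketch, assuming a Hugging Face tokenizer for your model (the model id below is only a placeholder):
```python
from transformers import AutoTokenizer

# Placeholder model id; use the model you are actually fine-tuning.
tokenizer = AutoTokenizer.from_pretrained("my-org/my-model")

# Encoding the response marker without special tokens yields the ids to put
# in the `response_template` list of the configuration.
print(tokenizer.encode("<|im_start|>assistant", add_special_tokens=False))
```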
### Evaluate
To compute test set perplexity use:

View File

@ -12,7 +12,7 @@ import mlx.optimizers as optim
import numpy as np
import yaml
from .tokenizer_utils import TokenizerWrapper
from .tokenizer_utils import TokenizerWrapper, no_bos_or_eos
from .tuner.datasets import load_dataset
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
from .tuner.utils import (
@ -63,6 +63,7 @@ CONFIG_DEFAULTS = {
"lr_schedule": None,
"hf_datasets": None,
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
"response_template": None,
}
@ -217,6 +218,17 @@ def train_model(
adapter_file = adapter_path / "adapters.safetensors"
save_config(vars(args), adapter_path / "adapter_config.json")
if isinstance(args.response_template, str):
    response_generation_tokens = tokenizer.encode(
        args.response_template, add_special_tokens=False
    )
else:
    if not all(isinstance(item, int) for item in args.response_template):
        raise ValueError(
            "Response template must be a list of integers if it is not a string."
        )
    response_generation_tokens = args.response_template
# init training args
training_args = TrainingArgs(
batch_size=args.batch_size,
@ -228,6 +240,9 @@ def train_model(
adapter_file=adapter_file,
max_seq_length=args.max_seq_length,
grad_checkpoint=args.grad_checkpoint,
response_generation_tokens=no_bos_or_eos(
response_generation_tokens, tokenizer.bos_token_id, tokenizer.eos_token_id
),
)
model.train()

View File

@ -1,5 +1,6 @@
import json
from functools import partial
from typing import List
from transformers import AutoTokenizer
@ -368,3 +369,8 @@ def load_tokenizer(model_path, tokenizer_config_extra={}, eos_token_ids=None):
detokenizer_class,
eos_token_ids=eos_token_ids,
)
def no_bos_or_eos(sequence: List, bos: int, eos: int) -> List:
    removed_bos = sequence if sequence[0] != bos else sequence[1:]
    return removed_bos[:-1] if removed_bos[-1] == eos else removed_bos
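A quick illustration of the helper's behavior (a sketch only; the ids 1 and 2 stand in for a tokenizer's
BOS and EOS, and 106/2516 are the Gemma marker from the README example):
```python
no_bos_or_eos([1, 106, 2516], bos=1, eos=2)  # -> [106, 2516]  (leading BOS dropped)
no_bos_or_eos([106, 2516, 2], bos=1, eos=2)  # -> [106, 2516]  (trailing EOS dropped)
no_bos_or_eos([106, 2516], bos=1, eos=2)     # -> [106, 2516]  (unchanged)
```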

View File

@ -64,7 +64,6 @@ class CompletionsDataset:
tokenizer: PreTrainedTokenizer,
prompt_key: str,
completion_key: str,
response_template: Union[str, list[int]] = None,
):
self._data = [
tokenizer.apply_chat_template(
@ -75,12 +74,6 @@ class CompletionsDataset:
)
for d in data
]
if isinstance(response_template, str):
self.response_token_ids = self._tokenizer.encode(
response_template, add_special_tokens=False
)
else:
self.response_token_ids = response_template
def __getitem__(self, idx: int):
return self._data[idx]

View File

@ -5,7 +5,7 @@ import shutil
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Tuple
from typing import List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
@ -64,6 +64,10 @@ class TrainingArgs:
default=False,
metadata={"help": "Use gradient checkpointing to reduce memory use."},
)
response_generation_tokens: Optional[List[int]] = field(
default_factory=list,
metadata={"help": "List of token ids that mark the beginning of the response"},
)
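To illustrate where the new field fits (a sketch only, not part of the diff; the marker ids are the Gemma
example from the README and the other values are illustrative), the trainer now receives the response
marker through `TrainingArgs` rather than through the dataset:
```python
# Sketch: the response-marker tokens reach the trainer via TrainingArgs
# instead of being attached to the CompletionsDataset.
args = TrainingArgs(
    batch_size=4,
    max_seq_length=2048,
    response_generation_tokens=[106, 2516],
)
```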
def input_masked_loss(model, inputs, response_prefix_lengths, lengths):
@ -114,6 +118,7 @@ def iterate_completion_batches(
batch_size: int,
max_seq_length: int,
train: bool = False,
response_generation_tokens: Optional[List[int]] = None,
):
"""
A version of iterate_batches that works with completion datasets, tracks the boundaries between input/output tokens
@ -146,14 +151,14 @@ def iterate_completion_batches(
if full_sequence[-1] != tokenizer.eos_token_id:
full_sequence.append(tokenizer.eos_token_id)
batch.append(full_sequence)
if len(dataset.response_token_ids) > 1:
if len(response_generation_tokens) > 1:
response_marker_begin, response_marker_end = contains(
dataset.response_token_ids, full_sequence
response_generation_tokens, full_sequence
)
response_prefix_lengths.append(response_marker_end + 1)
else:
response_marker_begin = full_sequence.index(
dataset.response_token_ids[0]
response_generation_tokens[0]
)
response_prefix_lengths.append(response_marker_begin + 1)
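The prefix-length logic above can be summarized in a standalone sketch (it inlines a simple sliding-window
search rather than calling the repo's `contains` helper, and raises if the marker is absent):
```python
from typing import List

def response_prefix_length(full_sequence: List[int], marker: List[int]) -> int:
    """Number of leading tokens to mask: everything through the response marker."""
    if len(marker) > 1:
        # Multi-token marker: find where it occurs and mask through its last token.
        for begin in range(len(full_sequence) - len(marker) + 1):
            if full_sequence[begin : begin + len(marker)] == marker:
                return begin + len(marker)
        raise ValueError("Response marker not found in sequence")
    # Single-token marker: mask through its first occurrence.
    return full_sequence.index(marker[0]) + 1
```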
@ -190,7 +195,14 @@ def iterate_completion_batches(
break
def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
def iterate_batches(
dataset,
tokenizer,
batch_size,
max_seq_length,
train=False,
response_generation_tokens=None,
):
# Sort by length:
idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
if len(dataset) < batch_size:
@ -253,6 +265,7 @@ def evaluate(
max_seq_length=2048,
loss: callable = default_loss,
iterate_batches: callable = iterate_batches,
response_generation_tokens: Optional[List[int]] = None,
):
all_losses = mx.array(0.0)
ntokens = mx.array(0)
@ -266,6 +279,7 @@ def evaluate(
tokenizer=tokenizer,
batch_size=batch_size,
max_seq_length=max_seq_length,
response_generation_tokens=response_generation_tokens,
),
):
losses, toks = loss(model, *batch)
@ -341,6 +355,7 @@ def train(
batch_size=args.batch_size,
max_seq_length=args.max_seq_length,
train=True,
response_generation_tokens=args.response_generation_tokens,
),
):
# Report validation loss if needed, the first validation loss
@ -356,6 +371,7 @@ def train(
num_batches=args.val_batches,
max_seq_length=args.max_seq_length,
iterate_batches=iterate_batches,
response_generation_tokens=args.response_generation_tokens,
)
val_time = time.perf_counter() - stop
if rank == 0: