Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-08-28 00:30:09 +08:00)

Move response template to LoRA configuration

This commit is contained in:
parent 95e1f22812
commit 7989d0a874
@@ -81,6 +81,22 @@ There are custom functions for masking the sequence of tokens associated with th
during the loss calculation to ensure the model is not being penalized for not recreating the prompt. To fine-tune
with masked input sequences, use the `--mask-inputs` argument.

This functionality expects a `response_template` parameter in the configuration, given either as
[a string that indicates the start of the model's response](https://huggingface.co/docs/transformers/en/chat_templating#what-are-generation-prompts)
or as its corresponding token ids. It is used to build the mask that excludes the tokens preceding the model's
response from the loss calculation. For example (ChatML):

```yaml
response_template: "<|im_start|>assistant"
```

or (for the corresponding tokens of Gemma's response template):

```yaml
response_template: [106, 2516]
```

### Evaluate

To compute test set perplexity use:
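The token form is tokenizer specific, so it is usually safest to derive it from the same tokenizer used for fine-tuning. A minimal sketch, assuming `transformers` is installed; the model name is only an illustrative choice of a ChatML-style chat model:

```python
from transformers import AutoTokenizer

# Illustrative model: any chat model whose template opens assistant turns with
# "<|im_start|>assistant" behaves the same way here.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

# Encode without special tokens so no BOS/EOS ids end up in the template.
ids = tokenizer.encode("<|im_start|>assistant", add_special_tokens=False)
print(ids)  # the exact ids depend on the tokenizer's vocabulary
```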
@@ -12,7 +12,7 @@ import mlx.optimizers as optim
 import numpy as np
 import yaml
 
-from .tokenizer_utils import TokenizerWrapper
+from .tokenizer_utils import TokenizerWrapper, no_bos_or_eos
 from .tuner.datasets import load_dataset
 from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
 from .tuner.utils import (
@@ -63,6 +63,7 @@ CONFIG_DEFAULTS = {
     "lr_schedule": None,
     "hf_datasets": None,
     "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
+    "response_template": None,
 }
 
 
@@ -217,6 +218,17 @@ def train_model(
     adapter_file = adapter_path / "adapters.safetensors"
     save_config(vars(args), adapter_path / "adapter_config.json")
 
+    if isinstance(args.response_template, str):
+        response_generation_tokens = tokenizer.encode(
+            args.response_template, add_special_tokens=False
+        )
+    else:
+        if not all(isinstance(item, int) for item in args.response_template):
+            raise ValueError(
+                "Response template must be a list of integers if it is not a string."
+            )
+        response_generation_tokens = args.response_template
+
     # init training args
     training_args = TrainingArgs(
         batch_size=args.batch_size,
@@ -228,6 +240,9 @@ def train_model(
         adapter_file=adapter_file,
         max_seq_length=args.max_seq_length,
         grad_checkpoint=args.grad_checkpoint,
+        response_generation_tokens=no_bos_or_eos(
+            response_generation_tokens, tokenizer.bos_token_id, tokenizer.eos_token_id
+        ),
     )
 
     model.train()
@@ -1,5 +1,6 @@
 import json
 from functools import partial
+from typing import List
 
 from transformers import AutoTokenizer
 
@@ -368,3 +369,8 @@ def load_tokenizer(model_path, tokenizer_config_extra={}, eos_token_ids=None):
         detokenizer_class,
         eos_token_ids=eos_token_ids,
     )
+
+
+def no_bos_or_eos(sequence: List, bos: int, eos: int) -> List:
+    removed_bos = sequence if sequence[0] != bos else sequence[1:]
+    return removed_bos[:-1] if removed_bos[-1] == eos else removed_bos
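A quick usage sketch for the new helper, using made-up token ids (bos = 1, eos = 2, with [5, 6] standing in for the response-template tokens); the import path assumes `mlx_lm` is installed with this change:

```python
from mlx_lm.tokenizer_utils import no_bos_or_eos

print(no_bos_or_eos([1, 5, 6], bos=1, eos=2))  # [5, 6]  leading BOS stripped
print(no_bos_or_eos([5, 6, 2], bos=1, eos=2))  # [5, 6]  trailing EOS stripped
print(no_bos_or_eos([5, 6], bos=1, eos=2))     # [5, 6]  already clean, unchanged
```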
@@ -64,7 +64,6 @@ class CompletionsDataset:
         tokenizer: PreTrainedTokenizer,
         prompt_key: str,
         completion_key: str,
-        response_template: Union[str, list[int]] = None,
     ):
         self._data = [
             tokenizer.apply_chat_template(
@@ -75,12 +74,6 @@ class CompletionsDataset:
             )
             for d in data
         ]
-        if isinstance(response_template, str):
-            self.response_token_ids = self._tokenizer.encode(
-                response_template, add_special_tokens=False
-            )
-        else:
-            self.response_token_ids = response_template
 
     def __getitem__(self, idx: int):
         return self._data[idx]
@@ -5,7 +5,7 @@ import shutil
 import time
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import List, Tuple
+from typing import List, Optional, Tuple
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -64,6 +64,10 @@ class TrainingArgs:
         default=False,
         metadata={"help": "Use gradient checkpointing to reduce memory use."},
     )
+    response_generation_tokens: Optional[List[int]] = field(
+        default_factory=list,
+        metadata={"help": "List of token ids that mark the beginning of the response"},
+    )
 
 
 def input_masked_loss(model, inputs, response_prefix_lengths, lengths):
@@ -114,6 +118,7 @@ def iterate_completion_batches(
     batch_size: int,
     max_seq_length: int,
     train: bool = False,
+    response_generation_tokens: Optional[List[int]] = None,
 ):
     """
     A version of iterate_batches that works with completion datasets, tracks the boundaries between input/output tokens
@@ -146,14 +151,14 @@
             if full_sequence[-1] != tokenizer.eos_token_id:
                 full_sequence.append(tokenizer.eos_token_id)
             batch.append(full_sequence)
-            if len(dataset.response_token_ids) > 1:
+            if len(response_generation_tokens) > 1:
                 response_marker_begin, response_marker_end = contains(
-                    dataset.response_token_ids, full_sequence
+                    response_generation_tokens, full_sequence
                 )
                 response_prefix_lengths.append(response_marker_end + 1)
             else:
                 response_marker_begin = full_sequence.index(
-                    dataset.response_token_ids[0]
+                    response_generation_tokens[0]
                 )
                 response_prefix_lengths.append(response_marker_begin + 1)
 
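The masking boundary above comes from locating the response-template tokens inside each tokenized sequence. A minimal sketch of that lookup with made-up token ids; `find_subsequence` is a hypothetical stand-in for the repo's `contains` helper:

```python
from typing import List, Tuple


def find_subsequence(marker: List[int], sequence: List[int]) -> Tuple[int, int]:
    # Hypothetical stand-in for `contains`: return the first and last index of
    # `marker` inside `sequence`.
    for start in range(len(sequence) - len(marker) + 1):
        if sequence[start : start + len(marker)] == marker:
            return start, start + len(marker) - 1
    raise ValueError("response template tokens not found in sequence")


# Made-up ids: BOS, prompt, response marker, answer, EOS.
sequence = [1, 10, 11, 106, 2516, 42, 43, 2]
begin, end = find_subsequence([106, 2516], sequence)
response_prefix_length = end + 1  # mirrors `response_marker_end + 1` above
print(response_prefix_length)     # 5 -> only [42, 43, 2] contribute to the loss
```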
@@ -190,7 +195,14 @@
             break
 
 
-def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
+def iterate_batches(
+    dataset,
+    tokenizer,
+    batch_size,
+    max_seq_length,
+    train=False,
+    response_generation_tokens=None,
+):
     # Sort by length:
     idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
     if len(dataset) < batch_size:
@@ -253,6 +265,7 @@ def evaluate(
     max_seq_length=2048,
     loss: callable = default_loss,
     iterate_batches: callable = iterate_batches,
+    response_generation_tokens: Optional[List[int]] = None,
 ):
     all_losses = mx.array(0.0)
     ntokens = mx.array(0)
@@ -266,6 +279,7 @@
             tokenizer=tokenizer,
             batch_size=batch_size,
             max_seq_length=max_seq_length,
+            response_generation_tokens=response_generation_tokens,
         ),
     ):
         losses, toks = loss(model, *batch)
@@ -341,6 +355,7 @@ def train(
             batch_size=args.batch_size,
             max_seq_length=args.max_seq_length,
             train=True,
+            response_generation_tokens=args.response_generation_tokens,
         ),
     ):
         # Report validation loss if needed, the first validation loss
@@ -356,6 +371,7 @@ def train(
                 num_batches=args.val_batches,
                 max_seq_length=args.max_seq_length,
                 iterate_batches=iterate_batches,
+                response_generation_tokens=args.response_generation_tokens,
             )
             val_time = time.perf_counter() - stop
             if rank == 0: