Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-08-29 06:54:39 +08:00)
Move response template to LoRA configuration
parent 95e1f22812
commit 7989d0a874
@@ -81,6 +81,22 @@ There are custom functions for masking the sequence of tokens associated with th
 during the loss calculation to ensure the model is not being penalized for not recreating the prompt. To fine-tune
 with masked input sequences, use the `--mask-inputs` argument.
 
+This functionality expects a `response_template` parameter in the configuration: either a
+[string that indicates the start of the model's response](https://huggingface.co/docs/transformers/en/chat_templating#what-are-generation-prompts)
+or its corresponding token ids. It is used to create the mask that excludes the tokens
+preceding the response from the loss calculation. For example (ChatML):
+
+```yaml
+response_template: "<|im_start|>assistant"
+```
+
+or (for the corresponding token ids of Gemma's response template):
+
+```yaml
+response_template: [106, 2516]
+```
+
+
 ### Evaluate
 
 To compute test set perplexity use:
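If you want the token-id form, one way to derive it is to encode the template string without special tokens. A minimal sketch, assuming a Hugging Face tokenizer; the model id and the printed ids are illustrative, not guaranteed:

```python
from transformers import AutoTokenizer

# Illustrative only: the model id is an assumption, and the resulting
# ids depend entirely on the tokenizer you actually use.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
ids = tokenizer.encode("<start_of_turn>model", add_special_tokens=False)
print(ids)  # e.g. [106, 2516] for a Gemma-style response template
```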
@@ -12,7 +12,7 @@ import mlx.optimizers as optim
 import numpy as np
 import yaml
 
-from .tokenizer_utils import TokenizerWrapper
+from .tokenizer_utils import TokenizerWrapper, no_bos_or_eos
 from .tuner.datasets import load_dataset
 from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
 from .tuner.utils import (
@@ -63,6 +63,7 @@ CONFIG_DEFAULTS = {
     "lr_schedule": None,
     "hf_datasets": None,
     "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
+    "response_template": None,
 }
 
 
@@ -217,6 +218,17 @@ def train_model(
     adapter_file = adapter_path / "adapters.safetensors"
     save_config(vars(args), adapter_path / "adapter_config.json")
 
+    if isinstance(args.response_template, str):
+        response_generation_tokens = tokenizer.encode(
+            args.response_template, add_special_tokens=False
+        )
+    else:
+        if not all(isinstance(item, int) for item in args.response_template):
+            raise ValueError(
+                "Response template must be a list of integers if it is not a string."
+            )
+        response_generation_tokens = args.response_template
+
     # init training args
     training_args = TrainingArgs(
         batch_size=args.batch_size,
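The branch above accepts both configuration forms. A standalone sketch of the same normalization, with a hypothetical `normalize_response_template` helper and a fake encoder standing in for `tokenizer.encode(..., add_special_tokens=False)`:

```python
from typing import Callable, List, Union

def normalize_response_template(
    template: Union[str, List[int]], encode: Callable[[str], List[int]]
) -> List[int]:
    # Strings are tokenized; lists are type-checked and passed through.
    if isinstance(template, str):
        return encode(template)
    if not all(isinstance(item, int) for item in template):
        raise ValueError(
            "Response template must be a list of integers if it is not a string."
        )
    return list(template)

fake_encode = lambda s: [1, 2, 3]  # stand-in for a real tokenizer
assert normalize_response_template("<|im_start|>assistant", fake_encode) == [1, 2, 3]
assert normalize_response_template([106, 2516], fake_encode) == [106, 2516]
```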
@@ -228,6 +240,9 @@ def train_model(
         adapter_file=adapter_file,
         max_seq_length=args.max_seq_length,
         grad_checkpoint=args.grad_checkpoint,
+        response_generation_tokens=no_bos_or_eos(
+            response_generation_tokens, tokenizer.bos_token_id, tokenizer.eos_token_id
+        ),
     )
 
     model.train()
@@ -1,5 +1,6 @@
 import json
 from functools import partial
+from typing import List
 
 from transformers import AutoTokenizer
 
@@ -368,3 +369,8 @@ def load_tokenizer(model_path, tokenizer_config_extra={}, eos_token_ids=None):
         detokenizer_class,
         eos_token_ids=eos_token_ids,
     )
+
+
+def no_bos_or_eos(sequence: List, bos: int, eos: int) -> List:
+    removed_bos = sequence if sequence[0] != bos else sequence[1:]
+    return removed_bos[:-1] if removed_bos[-1] == eos else removed_bos
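To see what `no_bos_or_eos` does, a quick self-contained check with made-up token ids (the function body is copied from the hunk above; bos=2 and eos=1 are arbitrary):

```python
from typing import List

def no_bos_or_eos(sequence: List, bos: int, eos: int) -> List:
    # Copied from the hunk above so the checks below run standalone.
    removed_bos = sequence if sequence[0] != bos else sequence[1:]
    return removed_bos[:-1] if removed_bos[-1] == eos else removed_bos

assert no_bos_or_eos([2, 106, 2516], bos=2, eos=1) == [106, 2516]  # strips leading BOS
assert no_bos_or_eos([106, 2516, 1], bos=2, eos=1) == [106, 2516]  # strips trailing EOS
assert no_bos_or_eos([106, 2516], bos=2, eos=1) == [106, 2516]     # otherwise unchanged
```

This guards against user-supplied token lists that accidentally include BOS or EOS markers, which would never match inside a training sequence.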
@@ -64,7 +64,6 @@ class CompletionsDataset:
         tokenizer: PreTrainedTokenizer,
         prompt_key: str,
         completion_key: str,
-        response_template: Union[str, list[int]] = None,
     ):
         self._data = [
             tokenizer.apply_chat_template(
@@ -75,12 +74,6 @@ class CompletionsDataset:
             )
             for d in data
         ]
-        if isinstance(response_template, str):
-            self.response_token_ids = self._tokenizer.encode(
-                response_template, add_special_tokens=False
-            )
-        else:
-            self.response_token_ids = response_template
 
     def __getitem__(self, idx: int):
         return self._data[idx]
@@ -5,7 +5,7 @@ import shutil
 import time
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import List, Tuple
+from typing import List, Optional, Tuple
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -64,6 +64,10 @@ class TrainingArgs:
         default=False,
         metadata={"help": "Use gradient checkpointing to reduce memory use."},
     )
+    response_generation_tokens: Optional[List[int]] = field(
+        default_factory=list,
+        metadata={"help": "List of token ids that mark the beginning of the response"},
+    )
 
 
 def input_masked_loss(model, inputs, response_prefix_lengths, lengths):
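A hedged construction example for the new field (the import path assumes the installed `mlx_lm` package layout; the ids reuse the Gemma example from the docs above):

```python
from mlx_lm.tuner.trainer import TrainingArgs

args = TrainingArgs(
    batch_size=4,
    max_seq_length=2048,
    response_generation_tokens=[106, 2516],  # response marker ids
)
print(args.response_generation_tokens)  # [106, 2516]
```

Note the `default_factory=list` default: when no template is configured, the field is an empty list rather than None.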
@@ -114,6 +118,7 @@ def iterate_completion_batches(
     batch_size: int,
     max_seq_length: int,
     train: bool = False,
+    response_generation_tokens: Optional[List[int]] = None,
 ):
     """
     A version of iterate_batches that works with completion datasets, tracks the boundaries between input/output tokens
@@ -146,14 +151,14 @@ def iterate_completion_batches(
             if full_sequence[-1] != tokenizer.eos_token_id:
                 full_sequence.append(tokenizer.eos_token_id)
             batch.append(full_sequence)
-            if len(dataset.response_token_ids) > 1:
+            if len(response_generation_tokens) > 1:
                 response_marker_begin, response_marker_end = contains(
-                    dataset.response_token_ids, full_sequence
+                    response_generation_tokens, full_sequence
                 )
                 response_prefix_lengths.append(response_marker_end + 1)
             else:
                 response_marker_begin = full_sequence.index(
-                    dataset.response_token_ids[0]
+                    response_generation_tokens[0]
                 )
                 response_prefix_lengths.append(response_marker_begin + 1)
 
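The `contains` helper used here is not part of this diff. A naive sketch of the subsequence search it presumably performs (the real implementation may differ), returning inclusive begin/end indices consistent with the `response_marker_end + 1` prefix arithmetic above:

```python
from typing import List, Tuple

def contains(needle: List[int], haystack: List[int]) -> Tuple[int, int]:
    # Find the first occurrence of `needle` as a contiguous subsequence
    # of `haystack`; both returned indices are inclusive.
    n = len(needle)
    for begin in range(len(haystack) - n + 1):
        if haystack[begin : begin + n] == needle:
            return begin, begin + n - 1
    raise ValueError("response template tokens not found in sequence")

assert contains([106, 2516], [2, 5, 106, 2516, 9]) == (2, 3)
```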
@@ -190,7 +195,14 @@ def iterate_completion_batches(
             break
 
 
-def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
+def iterate_batches(
+    dataset,
+    tokenizer,
+    batch_size,
+    max_seq_length,
+    train=False,
+    response_generation_tokens=None,
+):
     # Sort by length:
     idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
     if len(dataset) < batch_size:
@@ -253,6 +265,7 @@ def evaluate(
     max_seq_length=2048,
     loss: callable = default_loss,
     iterate_batches: callable = iterate_batches,
+    response_generation_tokens: Optional[List[int]] = None,
 ):
     all_losses = mx.array(0.0)
     ntokens = mx.array(0)
@@ -266,6 +279,7 @@ def evaluate(
             tokenizer=tokenizer,
             batch_size=batch_size,
             max_seq_length=max_seq_length,
+            response_generation_tokens=response_generation_tokens,
         ),
     ):
         losses, toks = loss(model, *batch)
@@ -341,6 +355,7 @@ def train(
             batch_size=args.batch_size,
             max_seq_length=args.max_seq_length,
             train=True,
+            response_generation_tokens=args.response_generation_tokens,
         ),
     ):
         # Report validation loss if needed, the first validation loss
@@ -356,6 +371,7 @@ def train(
                 num_batches=args.val_batches,
                 max_seq_length=args.max_seq_length,
                 iterate_batches=iterate_batches,
+                response_generation_tokens=args.response_generation_tokens,
             )
             val_time = time.perf_counter() - stop
             if rank == 0: