Completion-only fine-tuning of instruction models with collections of HF datasets (#1103)

- Optional completion-only fine-tuning with `--mask-prompt`
- Collections of Hugging Face datasets

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Chime Ogbuji 2025-02-09 23:12:34 -05:00 committed by GitHub
parent 1ced1b00ca
commit 5865899c81
6 changed files with 199 additions and 85 deletions

View File

@@ -76,6 +76,14 @@ You can specify the output location with `--adapter-path`.
You can resume fine-tuning with an existing adapter with
`--resume-adapter-file <path_to_adapters.safetensors>`.
#### Prompt Masking
The default training computes a loss for every token in the sample. You can
ignore the prompt and compute loss for just the completion by passing
`--mask-prompt`. Note this is only supported for `chat` and `completion`
datasets. For `chat` datasets the final message in the message list is
considered the completion. See the [dataset section](#Data) for more details.
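For example, given a hypothetical `chat` sample like the one below, passing
`--mask-prompt` computes the loss only over the tokens of the final assistant
message:

```jsonl
{"messages": [{"role": "user", "content": "What is the capital of France?"}, {"role": "assistant", "content": "Paris."}]}
```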
### Evaluate

To compute test set perplexity use:
@@ -290,11 +298,27 @@ hf_dataset:
- Use `prompt_feature` and `completion_feature` to specify keys for a
  `completions` dataset. Use `text_feature` to specify the key for a `text`
  dataset. Use `chat_feature` to specify the key for a chat dataset.
- To specify the train, valid, or test splits, set the corresponding
  `{train,valid,test}_split` argument.
You can specify a collection of Hugging Face datasets with a list of records,
each with the same structure as above. For example:
```yaml
hf_dataset:
- name: "Open-Orca/OpenOrca"
train_split: "train[:90%]"
valid_split: "train[-10%:]"
prompt_feature: "question"
completion_feature: "response"
- name: "trl-lib/ultrafeedback_binarized"
train_split: "train[:90%]"
valid_split: "train[-10%:]"
chat_feature: "chosen"
```
- Arguments specified in `config` will be passed as keyword arguments to
  [`datasets.load_dataset`](https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset).

View File

@@ -94,6 +94,14 @@ def build_parser():
        choices=["lora", "dora", "full"],
        help="Type of fine-tuning to perform: lora, dora, or full.",
    )
    parser.add_argument(
        "--mask-prompt",
        action="store_true",
        help="Mask the prompt in the loss when training",
        default=False,
    )
    parser.add_argument(
        "--num-layers",
        type=int,
@@ -219,6 +227,7 @@ def train_model(
            build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
        )
    )

    # Train model
    train(
        model=model,

View File

@@ -1,5 +1,6 @@
import json
from functools import partial
from typing import List

from transformers import AutoTokenizer

@@ -368,3 +369,8 @@ def load_tokenizer(model_path, tokenizer_config_extra={}, eos_token_ids=None):
        detokenizer_class,
        eos_token_ids=eos_token_ids,
    )
def no_bos_or_eos(sequence: List, bos: int, eos: int) -> List:
    removed_bos = sequence if sequence[0] != bos else sequence[1:]
    return removed_bos[:-1] if removed_bos[-1] == eos else removed_bos
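`no_bos_or_eos` strips a leading BOS token and a trailing EOS token when
present, and otherwise leaves the sequence alone. A quick sketch with
hypothetical token IDs:

```python
# Hypothetical vocabulary: 1 is BOS, 2 is EOS.
sequence = [1, 15, 27, 99, 2]

# Leading BOS and trailing EOS are both removed.
assert no_bos_or_eos(sequence, bos=1, eos=2) == [15, 27, 99]

# A sequence without either marker is returned unchanged.
assert no_bos_or_eos([15, 27, 99], bos=1, eos=2) == [15, 27, 99]
```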

View File

@@ -1,6 +1,8 @@
import itertools
import json
import types
from pathlib import Path
from typing import Any, Dict, List, Optional

from transformers import PreTrainedTokenizer
@@ -34,14 +36,24 @@ class ChatDataset:

    https://platform.openai.com/docs/guides/fine-tuning/example-format
    """

    def __init__(
        self,
        data: List[Dict[str, str]],
        tokenizer: PreTrainedTokenizer,
        chat_key: str = "messages",
        mask_prompt: bool = False,
    ):
        self._data = []
        for d in data:
            messages = d[chat_key]
            tools = d.get("tools", None)
            tokens = tokenizer.apply_chat_template(messages, tools=tools)
            if mask_prompt:
                messages = messages[:-1]
                offset = len(tokenizer.apply_chat_template(messages, tools=tools))
                self._data.append((tokens, offset))
            else:
                self._data.append(tokens)

    def __getitem__(self, idx: int):
        return self._data[idx]
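With `mask_prompt` enabled, each sample becomes a `(tokens, offset)` pair,
where `offset` is the length of the chat template applied to every message
except the last; tokens before the offset are later excluded from the loss. A
minimal sketch of the idea, using a toy template over hypothetical token lists
(real chat templates also insert role markers, which is why the offset is
measured by re-templating rather than by counting characters):

```python
# Toy stand-in for tokenizer.apply_chat_template: concatenates token lists.
def toy_template(messages):
    return [tok for m in messages for tok in m["tokens"]]

messages = [
    {"role": "user", "tokens": [5, 6, 7]},    # prompt
    {"role": "assistant", "tokens": [8, 9]},  # completion
]

tokens = toy_template(messages)            # all tokens: [5, 6, 7, 8, 9]
offset = len(toy_template(messages[:-1]))  # 3 prompt tokens to mask
assert (tokens, offset) == ([5, 6, 7, 8, 9], 3)
```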
@@ -63,16 +75,36 @@ class CompletionsDataset:
        tokenizer: PreTrainedTokenizer,
        prompt_key: str,
        completion_key: str,
        mask_prompt: bool,
    ):
        self._data = []
        for d in data:
            tokens = tokenizer.apply_chat_template(
                [
                    {"role": "user", "content": d[prompt_key]},
                    {"role": "assistant", "content": d[completion_key]},
                ],
            )
            if mask_prompt:
                offset = len(
                    tokenizer.apply_chat_template(
                        [{"role": "user", "content": d[prompt_key]}]
                    )
                )
                self._data.append((tokens, offset))
            else:
                self._data.append(tokens)

    def __getitem__(self, idx: int):
        return self._data[idx]

    def __len__(self):
        return len(self._data)


class ConcatenatedDataset:
    def __init__(self, data: List[Any]):
        self._data = list(itertools.chain(*data))

    def __getitem__(self, idx: int):
        return self._data[idx]
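`ConcatenatedDataset` presents several datasets as a single one by flattening
them into one index space with `itertools.chain`. A minimal sketch with plain
lists standing in for dataset objects:

```python
import itertools

# Two hypothetical datasets of tokenized samples.
ds_a = [[1, 2], [3]]
ds_b = [[4, 5, 6]]

# list(itertools.chain(*data)) eagerly merges them into one sample list.
combined = list(itertools.chain(*[ds_a, ds_b]))
assert combined == [[1, 2], [3], [4, 5, 6]]
assert len(combined) == len(ds_a) + len(ds_b)
```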
@@ -84,18 +116,26 @@ class CompletionsDataset:
def create_dataset(
    data,
    tokenizer: PreTrainedTokenizer,
    config,
):
    mask_prompt = getattr(config, "mask_prompt", False)
    prompt_feature = getattr(config, "prompt_feature", "prompt")
    text_feature = getattr(config, "text_feature", "text")
    completion_feature = getattr(config, "completion_feature", "completion")
    chat_feature = getattr(config, "chat_feature", "messages")
    sample = data[0]
    if prompt_feature in sample and completion_feature in sample:
        return CompletionsDataset(
            data, tokenizer, prompt_feature, completion_feature, mask_prompt
        )
    elif chat_feature in sample:
        return ChatDataset(
            data, tokenizer, chat_key=chat_feature, mask_prompt=mask_prompt
        )
    elif text_feature in sample:
        if mask_prompt:
            raise ValueError("Prompt masking not supported for text dataset.")
        return Dataset(data, tokenizer, text_key=text_feature)
    else:
        raise ValueError(
            "Unsupported data format, check the supported formats here:\n"
@@ -106,15 +146,14 @@ def create_dataset(
def load_local_dataset(
    data_path: Path,
    tokenizer: PreTrainedTokenizer,
    config,
):
    def load_subset(path):
        if not path.exists():
            return []
        with open(path, "r") as fid:
            data = [json.loads(l) for l in fid]
        return create_dataset(data, tokenizer, config)

    names = ("train", "valid", "test")
    train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names]
@@ -124,8 +163,7 @@ def load_local_dataset(
def load_hf_dataset(
    data_id: str,
    tokenizer: PreTrainedTokenizer,
    config,
):
    from datasets import exceptions, load_dataset
@@ -136,9 +174,7 @@ def load_hf_dataset(
        train, valid, test = [
            (
                create_dataset(dataset[n], tokenizer, config)
                if n in dataset.keys()
                else []
            )
@@ -154,42 +190,61 @@ def load_hf_dataset(
def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
    import datasets

    def create_hf_dataset(dataset_name, config, split, hf_config):
        ds = datasets.load_dataset(
            dataset_name,
            split=split,
            **hf_config,
        )
        return create_dataset(ds, tokenizer, config)

    dataset_collection = args.hf_dataset
    if isinstance(dataset_collection, dict):
        dataset_collection = [dataset_collection]

    collection = []
    for ds in dataset_collection:
        ds_name = ds["name"]
        print(f"Loading Hugging Face dataset {ds_name}.")
        ds["mask_prompt"] = getattr(args, "mask_prompt", False)
        config = types.SimpleNamespace(**ds)
        hf_config = ds.get("config", {})
        if args.train:
            train_split = ds.get("train_split", "train[:80%]")
            valid_split = ds.get("valid_split", "train[-10%:]")
            train = create_hf_dataset(
                ds_name,
                config,
                train_split,
                hf_config,
            )
            valid = create_hf_dataset(
                ds_name,
                config,
                valid_split,
                hf_config,
            )
        else:
            train, valid = [], []
        if args.test:
            test_split = ds.get("test_split")
            test = create_hf_dataset(
                ds_name,
                config,
                test_split,
                hf_config,
            )
        else:
            test = []

        collection.append((train, valid, test))

    if len(collection) == 1:
        return collection[0]

    # Otherwise concatenate them
    return tuple(map(ConcatenatedDataset, zip(*collection)))
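Each entry of `collection` is a `(train, valid, test)` triple, so
`zip(*collection)` regroups the triples by split before each group is wrapped
in a `ConcatenatedDataset`. A small illustration of the regrouping with
hypothetical placeholder samples:

```python
# Two hypothetical (train, valid, test) triples from two datasets.
collection = [
    (["t1"], ["v1"], []),
    (["t2"], ["v2"], []),
]

# zip(*collection) gathers all train splits, then all valid, then all test.
trains, valids, tests = zip(*collection)
assert trains == (["t1"], ["t2"])
assert valids == (["v1"], ["v2"])
assert tests == ([], [])
```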
def load_dataset(args, tokenizer: PreTrainedTokenizer):
@@ -197,18 +252,11 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer):
        train, valid, test = load_custom_hf_dataset(args, tokenizer)
    else:
        data_path = Path(args.data)
        if data_path.exists():
            train, valid, test = load_local_dataset(data_path, tokenizer, args)
        else:
            print(f"Loading Hugging Face dataset {args.data}.")
            train, valid, test = load_hf_dataset(args.data, tokenizer, args)

    if args.train and len(train) == 0:
        raise ValueError(

View File

@@ -5,13 +5,16 @@ import shutil
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional, Tuple

import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten
from transformers import PreTrainedTokenizer

from .datasets import CompletionsDataset


def grad_checkpoint(layer):
@@ -63,20 +66,30 @@ class TrainingArgs:
    )


def default_loss(model, batch, lengths):
    inputs = batch[:, :-1]
    targets = batch[:, 1:]

    logits = model(inputs)
    logits = logits.astype(mx.float32)

    steps = mx.arange(1, targets.shape[1] + 1)
    mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:])

    ce = nn.losses.cross_entropy(logits, targets) * mask
    ntoks = mask.sum()
    ce = ce.sum() / ntoks

    return ce, ntoks
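Each row of `lengths` now carries an `(offset, length)` pair: the mask keeps
the positions from the first completion token through the last real token, and
samples without prompt masking simply use an offset of 0. A NumPy sketch of
the mask arithmetic with hypothetical values:

```python
import numpy as np

# (offset, length) per sample: sample 0 masks its 3 prompt tokens,
# sample 1 has offset 0 and only pads after position 4.
lengths = np.array([[3, 6], [0, 4]])

# targets are batch[:, 1:], so position i predicts token i + 1.
steps = np.arange(1, 7)
mask = (steps >= lengths[:, 0:1]) & (steps <= lengths[:, 1:])
print(mask.astype(int))
# [[0 0 1 1 1 1]   loss only on the completion tokens
#  [1 1 1 1 0 0]]  loss on every real token, none on padding
```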
def iterate_batches(
    dataset,
    tokenizer,
    batch_size,
    max_seq_length,
    train=False,
):
    # Sort by length:
    idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
    if len(dataset) < batch_size:
@@ -101,6 +114,10 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
        indices = np.random.permutation(len(batch_idx))
        for i in indices:
            batch = [dataset[j] for j in batch_idx[i]]
            if len(batch[0]) == 2:
                batch, offsets = zip(*batch)
            else:
                offsets = [0] * len(batch)
            lengths = [len(x) for x in batch]
            if max(lengths) > max_seq_length:
                print(
@@ -123,8 +140,7 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
                    truncated_length  # Update lengths to match truncated lengths
                )
            batch = mx.array(batch_arr)
            yield batch, mx.array(list(zip(offsets, lengths)))

        if not train:
            break

View File

@@ -78,14 +78,15 @@ class TestDatasets(unittest.TestCase):
        self.assertTrue(isinstance(train, datasets.ChatDataset))

    def test_hf(self):
        hf_args = {
            "name": "billsum",
            "prompt_feature": "text",
            "completion_feature": "summary",
            "train_split": "train[:2%]",
            "valid_split": "train[-2%:]",
        }
        args = types.SimpleNamespace(
            hf_dataset=hf_args,
            test=False,
            train=True,
        )
@@ -97,6 +98,16 @@
        self.assertTrue(len(valid[0]) > 0)
        self.assertEqual(len(test), 0)

        args = types.SimpleNamespace(
            hf_dataset=[hf_args, hf_args],
            test=False,
            train=True,
        )
        train_double, valid_double, test_double = datasets.load_dataset(args, tokenizer)
        self.assertEqual(2 * len(train), len(train_double))
        self.assertEqual(2 * len(valid), len(valid_double))
        self.assertEqual(2 * len(test), len(test_double))
if __name__ == "__main__":
    unittest.main()