mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 01:17:28 +08:00
Completion only fine-tuning of instruction models with collections of HF datasets (#1103)
- Optional completion only fine-tuning with `--mask-prompt` - Collections of Hugging Face datasets --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
1ced1b00ca
commit
5865899c81
@ -76,6 +76,14 @@ You can specify the output location with `--adapter-path`.
|
||||
You can resume fine-tuning with an existing adapter with
|
||||
`--resume-adapter-file <path_to_adapters.safetensors>`.
|
||||
|
||||
#### Prompt Masking
|
||||
|
||||
The default training computes a loss for every token in the sample. You can
|
||||
ignore the prompt and compute loss for just the completion by passing
|
||||
`--mask-prompt`. Note this is only supported for `chat` and `completion`
|
||||
datasets. For `chat` datasets the final message in the message list is
|
||||
considered the completion. See the [dataset section](#Data) for more details.
|
||||
|
||||
### Evaluate
|
||||
|
||||
To compute test set perplexity use:
|
||||
@ -290,11 +298,27 @@ hf_dataset:
|
||||
|
||||
- Use `prompt_feature` and `completion_feature` to specify keys for a
|
||||
`completions` dataset. Use `text_feature` to specify the key for a `text`
|
||||
dataset.
|
||||
dataset. Use `chat_feature` to specify the key for a chat dataset.
|
||||
|
||||
- To specify the train, valid, or test splits, set the corresponding
|
||||
`{train,valid,test}_split` argument.
|
||||
|
||||
You can specify a list of Hugging Face datasets with a list of records each
|
||||
with the same structure as above. For example:
|
||||
|
||||
```yaml
|
||||
hf_dataset:
|
||||
- name: "Open-Orca/OpenOrca"
|
||||
train_split: "train[:90%]"
|
||||
valid_split: "train[-10%:]"
|
||||
prompt_feature: "question"
|
||||
completion_feature: "response"
|
||||
- name: "trl-lib/ultrafeedback_binarized"
|
||||
train_split: "train[:90%]"
|
||||
valid_split: "train[-10%:]"
|
||||
chat_feature: "chosen"
|
||||
```
|
||||
|
||||
- Arguments specified in `config` will be passed as keyword arguments to
|
||||
[`datasets.load_dataset`](https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset).
|
||||
|
||||
|
@ -94,6 +94,14 @@ def build_parser():
|
||||
choices=["lora", "dora", "full"],
|
||||
help="Type of fine-tuning to perform: lora, dora, or full.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--mask-prompt",
|
||||
action="store_true",
|
||||
help="Mask the prompt in the loss when training",
|
||||
default=False,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-layers",
|
||||
type=int,
|
||||
@ -219,6 +227,7 @@ def train_model(
|
||||
build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
|
||||
)
|
||||
)
|
||||
|
||||
# Train model
|
||||
train(
|
||||
model=model,
|
||||
|
@ -1,5 +1,6 @@
|
||||
import json
|
||||
from functools import partial
|
||||
from typing import List
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
@ -368,3 +369,8 @@ def load_tokenizer(model_path, tokenizer_config_extra={}, eos_token_ids=None):
|
||||
detokenizer_class,
|
||||
eos_token_ids=eos_token_ids,
|
||||
)
|
||||
|
||||
|
||||
def no_bos_or_eos(sequence: List, bos: int, eos: int) -> List:
|
||||
removed_bos = sequence if sequence[0] != bos else sequence[1:]
|
||||
return removed_bos[:-1] if removed_bos[-1] == eos else removed_bos
|
||||
|
@ -1,6 +1,8 @@
|
||||
import itertools
|
||||
import json
|
||||
import types
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
@ -34,14 +36,24 @@ class ChatDataset:
|
||||
https://platform.openai.com/docs/guides/fine-tuning/example-format
|
||||
"""
|
||||
|
||||
def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer):
|
||||
self._data = [
|
||||
tokenizer.apply_chat_template(
|
||||
d["messages"],
|
||||
tools=d.get("tools", None),
|
||||
)
|
||||
for d in data
|
||||
]
|
||||
def __init__(
|
||||
self,
|
||||
data: List[Dict[str, str]],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
chat_key: str = "messages",
|
||||
mask_prompt: bool = False,
|
||||
):
|
||||
self._data = []
|
||||
for d in data:
|
||||
messages = d[chat_key]
|
||||
tools = d.get("tools", None)
|
||||
tokens = tokenizer.apply_chat_template(messages, tools=tools)
|
||||
if mask_prompt:
|
||||
messages = messages[:-1]
|
||||
offset = len(tokenizer.apply_chat_template(messages, tools=tools))
|
||||
self._data.append((tokens, offset))
|
||||
else:
|
||||
self._data.append(tokens)
|
||||
|
||||
def __getitem__(self, idx: int):
|
||||
return self._data[idx]
|
||||
@ -63,16 +75,36 @@ class CompletionsDataset:
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
prompt_key: str,
|
||||
completion_key: str,
|
||||
mask_prompt: bool,
|
||||
):
|
||||
self._data = [
|
||||
tokenizer.apply_chat_template(
|
||||
self._data = []
|
||||
for d in data:
|
||||
tokens = tokenizer.apply_chat_template(
|
||||
[
|
||||
{"role": "user", "content": d[prompt_key]},
|
||||
{"role": "assistant", "content": d[completion_key]},
|
||||
],
|
||||
)
|
||||
for d in data
|
||||
]
|
||||
if mask_prompt:
|
||||
offset = len(
|
||||
tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": d[prompt_key]}]
|
||||
)
|
||||
)
|
||||
self._data.append((tokens, offset))
|
||||
else:
|
||||
self._data.append(tokens)
|
||||
|
||||
def __getitem__(self, idx: int):
|
||||
return self._data[idx]
|
||||
|
||||
def __len__(self):
|
||||
return len(self._data)
|
||||
|
||||
|
||||
class ConcatenatedDataset:
|
||||
def __init__(self, data: List[Any]):
|
||||
self._data = list(itertools.chain(*data))
|
||||
|
||||
def __getitem__(self, idx: int):
|
||||
return self._data[idx]
|
||||
@ -84,18 +116,26 @@ class CompletionsDataset:
|
||||
def create_dataset(
|
||||
data,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
prompt_feature: Optional[str] = None,
|
||||
completion_feature: Optional[str] = None,
|
||||
config,
|
||||
):
|
||||
prompt_feature = prompt_feature or "prompt"
|
||||
completion_feature = completion_feature or "completion"
|
||||
mask_prompt = getattr(config, "mask_prompt", False)
|
||||
prompt_feature = getattr(config, "prompt_feature", "prompt")
|
||||
text_feature = getattr(config, "text_feature", "text")
|
||||
completion_feature = getattr(config, "completion_feature", "completion")
|
||||
chat_feature = getattr(config, "chat_feature", "messages")
|
||||
sample = data[0]
|
||||
if "messages" in sample:
|
||||
return ChatDataset(data, tokenizer)
|
||||
elif prompt_feature in sample and completion_feature in sample:
|
||||
return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature)
|
||||
elif "text" in sample:
|
||||
return Dataset(data, tokenizer)
|
||||
if prompt_feature in sample and completion_feature in sample:
|
||||
return CompletionsDataset(
|
||||
data, tokenizer, prompt_feature, completion_feature, mask_prompt
|
||||
)
|
||||
elif chat_feature in sample:
|
||||
return ChatDataset(
|
||||
data, tokenizer, chat_key=chat_feature, mask_prompt=mask_prompt
|
||||
)
|
||||
elif text_feature in sample:
|
||||
if mask_prompt:
|
||||
raise ValueError("Prompt masking not supported for text dataset.")
|
||||
return Dataset(data, tokenizer, text_key=text_feature)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported data format, check the supported formats here:\n"
|
||||
@ -106,15 +146,14 @@ def create_dataset(
|
||||
def load_local_dataset(
|
||||
data_path: Path,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
prompt_feature: Optional[str] = None,
|
||||
completion_feature: Optional[str] = None,
|
||||
config,
|
||||
):
|
||||
def load_subset(path):
|
||||
if not path.exists():
|
||||
return []
|
||||
with open(path, "r") as fid:
|
||||
data = [json.loads(l) for l in fid]
|
||||
return create_dataset(data, tokenizer, prompt_feature, completion_feature)
|
||||
return create_dataset(data, tokenizer, config)
|
||||
|
||||
names = ("train", "valid", "test")
|
||||
train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names]
|
||||
@ -124,8 +163,7 @@ def load_local_dataset(
|
||||
def load_hf_dataset(
|
||||
data_id: str,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
prompt_feature: Optional[str] = None,
|
||||
completion_feature: Optional[str] = None,
|
||||
config,
|
||||
):
|
||||
from datasets import exceptions, load_dataset
|
||||
|
||||
@ -136,9 +174,7 @@ def load_hf_dataset(
|
||||
|
||||
train, valid, test = [
|
||||
(
|
||||
create_dataset(
|
||||
dataset[n], tokenizer, prompt_feature, completion_feature
|
||||
)
|
||||
create_dataset(dataset[n], tokenizer, config)
|
||||
if n in dataset.keys()
|
||||
else []
|
||||
)
|
||||
@ -154,42 +190,61 @@ def load_hf_dataset(
|
||||
def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
|
||||
import datasets
|
||||
|
||||
hf_args = args.hf_dataset
|
||||
dataset_name = hf_args["name"]
|
||||
print(f"Loading Hugging Face dataset {dataset_name}.")
|
||||
text_feature = hf_args.get("text_feature")
|
||||
prompt_feature = hf_args.get("prompt_feature")
|
||||
completion_feature = hf_args.get("completion_feature")
|
||||
|
||||
def create_hf_dataset(split: str = None):
|
||||
def create_hf_dataset(dataset_name, config, split, hf_config):
|
||||
ds = datasets.load_dataset(
|
||||
dataset_name,
|
||||
split=split,
|
||||
**hf_args.get("config", {}),
|
||||
**hf_config,
|
||||
)
|
||||
if prompt_feature and completion_feature:
|
||||
return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
|
||||
elif text_feature:
|
||||
return Dataset(ds, tokenizer, text_key=text_feature)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Specify either a prompt and completion feature or a text "
|
||||
"feature for the Hugging Face dataset."
|
||||
return create_dataset(ds, tokenizer, config)
|
||||
|
||||
dataset_collection = args.hf_dataset
|
||||
if isinstance(dataset_collection, dict):
|
||||
dataset_collection = [dataset_collection]
|
||||
|
||||
collection = []
|
||||
for ds in dataset_collection:
|
||||
ds_name = ds["name"]
|
||||
print(f"Loading Hugging Face dataset {ds_name}.")
|
||||
ds["mask_prompt"] = getattr(args, "mask_prompt", False)
|
||||
config = types.SimpleNamespace(**ds)
|
||||
hf_config = ds.get("config", {})
|
||||
if args.train:
|
||||
train_split = ds.get("train_split", "train[:80%]")
|
||||
valid_split = ds.get("valid_split", "train[-10%:]")
|
||||
train = create_hf_dataset(
|
||||
ds_name,
|
||||
config,
|
||||
train_split,
|
||||
hf_config,
|
||||
)
|
||||
valid = create_hf_dataset(
|
||||
ds_name,
|
||||
config,
|
||||
valid_split,
|
||||
hf_config,
|
||||
)
|
||||
else:
|
||||
train, valid = [], []
|
||||
|
||||
if args.train:
|
||||
train_split = hf_args.get("train_split", "train[:80%]")
|
||||
valid_split = hf_args.get("valid_split", "train[-10%:]")
|
||||
train = create_hf_dataset(split=train_split)
|
||||
valid = create_hf_dataset(split=valid_split)
|
||||
else:
|
||||
train, valid = [], []
|
||||
if args.test:
|
||||
test = create_hf_dataset(split=hf_args.get("test_split"))
|
||||
else:
|
||||
test = []
|
||||
if args.test:
|
||||
test_split = ds.get("test_split")
|
||||
test = create_hf_dataset(
|
||||
ds_name,
|
||||
config,
|
||||
test_split,
|
||||
hf_config,
|
||||
)
|
||||
else:
|
||||
test = []
|
||||
|
||||
return train, valid, test
|
||||
collection.append((train, valid, test))
|
||||
|
||||
if len(collection) == 1:
|
||||
return collection[0]
|
||||
|
||||
# Otherwise concatenate them
|
||||
return tuple(map(ConcatenatedDataset, zip(*collection)))
|
||||
|
||||
|
||||
def load_dataset(args, tokenizer: PreTrainedTokenizer):
|
||||
@ -197,18 +252,11 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer):
|
||||
train, valid, test = load_custom_hf_dataset(args, tokenizer)
|
||||
else:
|
||||
data_path = Path(args.data)
|
||||
|
||||
prompt_feature = getattr(args, "prompt_feature", None)
|
||||
completion_feature = getattr(args, "completion_feature", None)
|
||||
if data_path.exists():
|
||||
train, valid, test = load_local_dataset(
|
||||
data_path, tokenizer, prompt_feature, completion_feature
|
||||
)
|
||||
train, valid, test = load_local_dataset(data_path, tokenizer, args)
|
||||
else:
|
||||
print(f"Loading Hugging Face dataset {args.data}.")
|
||||
train, valid, test = load_hf_dataset(
|
||||
args.data, tokenizer, prompt_feature, completion_feature
|
||||
)
|
||||
train, valid, test = load_hf_dataset(args.data, tokenizer, args)
|
||||
|
||||
if args.train and len(train) == 0:
|
||||
raise ValueError(
|
||||
|
@ -5,13 +5,16 @@ import shutil
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
from mlx.nn.utils import average_gradients
|
||||
from mlx.utils import tree_flatten
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from .datasets import CompletionsDataset
|
||||
|
||||
|
||||
def grad_checkpoint(layer):
|
||||
@ -63,20 +66,30 @@ class TrainingArgs:
|
||||
)
|
||||
|
||||
|
||||
def default_loss(model, inputs, targets, lengths):
|
||||
def default_loss(model, batch, lengths):
|
||||
inputs = batch[:, :-1]
|
||||
targets = batch[:, 1:]
|
||||
|
||||
logits = model(inputs)
|
||||
logits = logits.astype(mx.float32)
|
||||
|
||||
length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
|
||||
steps = mx.arange(1, targets.shape[1] + 1)
|
||||
mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:])
|
||||
|
||||
ce = nn.losses.cross_entropy(logits, targets) * length_mask
|
||||
ntoks = length_mask.sum()
|
||||
ce = nn.losses.cross_entropy(logits, targets) * mask
|
||||
ntoks = mask.sum()
|
||||
ce = ce.sum() / ntoks
|
||||
|
||||
return ce, ntoks
|
||||
|
||||
|
||||
def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
|
||||
def iterate_batches(
|
||||
dataset,
|
||||
tokenizer,
|
||||
batch_size,
|
||||
max_seq_length,
|
||||
train=False,
|
||||
):
|
||||
# Sort by length:
|
||||
idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
|
||||
if len(dataset) < batch_size:
|
||||
@ -101,6 +114,10 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
|
||||
indices = np.random.permutation(len(batch_idx))
|
||||
for i in indices:
|
||||
batch = [dataset[j] for j in batch_idx[i]]
|
||||
if len(batch[0]) == 2:
|
||||
batch, offsets = zip(*batch)
|
||||
else:
|
||||
offsets = [0] * len(batch)
|
||||
lengths = [len(x) for x in batch]
|
||||
if max(lengths) > max_seq_length:
|
||||
print(
|
||||
@ -123,8 +140,7 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
|
||||
truncated_length # Update lengths to match truncated lengths
|
||||
)
|
||||
batch = mx.array(batch_arr)
|
||||
|
||||
yield batch[:, :-1], batch[:, 1:], mx.array(lengths)
|
||||
yield batch, mx.array(list(zip(offsets, lengths)))
|
||||
|
||||
if not train:
|
||||
break
|
||||
|
@ -78,14 +78,15 @@ class TestDatasets(unittest.TestCase):
|
||||
self.assertTrue(isinstance(train, datasets.ChatDataset))
|
||||
|
||||
def test_hf(self):
|
||||
hf_args = {
|
||||
"name": "billsum",
|
||||
"prompt_feature": "text",
|
||||
"completion_feature": "summary",
|
||||
"train_split": "train[:2%]",
|
||||
"valid_split": "train[-2%:]",
|
||||
}
|
||||
args = types.SimpleNamespace(
|
||||
hf_dataset={
|
||||
"name": "billsum",
|
||||
"prompt_feature": "text",
|
||||
"completion_feature": "summary",
|
||||
"train_split": "train[:2%]",
|
||||
"valid_split": "train[-2%:]",
|
||||
},
|
||||
hf_dataset=hf_args,
|
||||
test=False,
|
||||
train=True,
|
||||
)
|
||||
@ -97,6 +98,16 @@ class TestDatasets(unittest.TestCase):
|
||||
self.assertTrue(len(valid[0]) > 0)
|
||||
self.assertEqual(len(test), 0)
|
||||
|
||||
args = types.SimpleNamespace(
|
||||
hf_dataset=[hf_args, hf_args],
|
||||
test=False,
|
||||
train=True,
|
||||
)
|
||||
train_double, valid_double, test_double = datasets.load_dataset(args, tokenizer)
|
||||
self.assertEqual(2 * len(train), len(train_double))
|
||||
self.assertEqual(2 * len(valid), len(valid_double))
|
||||
self.assertEqual(2 * len(test), len(test_double))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user