Completion only fine-tuning of instruction models with collections of HF datasets (#1103)

- Optional completion only fine-tuning with `--mask-prompt`
- Collections of Hugging Face datasets
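
A minimal usage sketch (not part of this change), shown below: assuming the CLI forwards `--mask-prompt` and the `hf_dataset` section of the config into the `args` namespace, loading a collection of two Hugging Face datasets with prompt masking could look like this. The model and dataset names are placeholders.

```python
# Illustrative only -- model and dataset names are placeholders.
from types import SimpleNamespace

from mlx_lm import load
from mlx_lm.tuner.datasets import load_dataset

_, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

args = SimpleNamespace(
    train=True,
    test=False,
    mask_prompt=True,  # the behavior enabled by --mask-prompt
    hf_dataset=[       # a list of entries is treated as a collection
        {"name": "org/chat-data", "chat_feature": "messages"},
        {"name": "org/qa-data", "prompt_feature": "question", "completion_feature": "answer"},
    ],
)

train_set, valid_set, test_set = load_dataset(args, tokenizer)
```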

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Author: Chime Ogbuji
Date: 2025-02-09 23:12:34 -05:00 (committed by GitHub)
Parent: 1ced1b00ca
Commit: 5865899c81
6 changed files with 199 additions and 85 deletions

mlx_lm/tuner/datasets.py

@@ -1,6 +1,8 @@
+import itertools
 import json
+import types
 from pathlib import Path
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional
 
 from transformers import PreTrainedTokenizer
@@ -34,14 +36,24 @@ class ChatDataset:
     https://platform.openai.com/docs/guides/fine-tuning/example-format
     """
 
-    def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer):
-        self._data = [
-            tokenizer.apply_chat_template(
-                d["messages"],
-                tools=d.get("tools", None),
-            )
-            for d in data
-        ]
+    def __init__(
+        self,
+        data: List[Dict[str, str]],
+        tokenizer: PreTrainedTokenizer,
+        chat_key: str = "messages",
+        mask_prompt: bool = False,
+    ):
+        self._data = []
+        for d in data:
+            messages = d[chat_key]
+            tools = d.get("tools", None)
+            tokens = tokenizer.apply_chat_template(messages, tools=tools)
+            if mask_prompt:
+                messages = messages[:-1]
+                offset = len(tokenizer.apply_chat_template(messages, tools=tools))
+                self._data.append((tokens, offset))
+            else:
+                self._data.append(tokens)
 
     def __getitem__(self, idx: int):
         return self._data[idx]
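
With `mask_prompt` enabled, each example is stored as a `(tokens, offset)` pair, where `offset` is the length of the chat template rendered without the final assistant message. A toy stub (not from the repo) illustrates the bookkeeping:

```python
# Stub tokenizer: pretend every message renders to 3 tokens.
class StubTokenizer:
    def apply_chat_template(self, messages, tools=None):
        return list(range(3 * len(messages)))

tok = StubTokenizer()
messages = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
]
tokens = tok.apply_chat_template(messages)             # full conversation: 6 tokens
offset = len(tok.apply_chat_template(messages[:-1]))   # prompt only: 3 tokens
print(tokens, offset)  # loss will be computed on tokens[offset:] only
```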
@@ -63,16 +75,36 @@ class CompletionsDataset:
         tokenizer: PreTrainedTokenizer,
         prompt_key: str,
         completion_key: str,
+        mask_prompt: bool,
     ):
-        self._data = [
-            tokenizer.apply_chat_template(
+        self._data = []
+        for d in data:
+            tokens = tokenizer.apply_chat_template(
                 [
                     {"role": "user", "content": d[prompt_key]},
                     {"role": "assistant", "content": d[completion_key]},
                 ],
             )
-            for d in data
-        ]
+            if mask_prompt:
+                offset = len(
+                    tokenizer.apply_chat_template(
+                        [{"role": "user", "content": d[prompt_key]}]
+                    )
+                )
+                self._data.append((tokens, offset))
+            else:
+                self._data.append(tokens)
 
     def __getitem__(self, idx: int):
         return self._data[idx]
 
     def __len__(self):
         return len(self._data)
+
+
+class ConcatenatedDataset:
+    def __init__(self, data: List[Any]):
+        self._data = list(itertools.chain(*data))
+
+    def __getitem__(self, idx: int):
+        return self._data[idx]
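
`ConcatenatedDataset` simply chains the per-dataset example lists into one indexable sequence; a quick illustration (the import path assumes this module):

```python
from mlx_lm.tuner.datasets import ConcatenatedDataset

train_a = [[1, 2, 3], [4, 5]]   # toy "tokenized" examples
train_b = [[6, 7]]
combined = ConcatenatedDataset([train_a, train_b])
print(combined[0])  # [1, 2, 3] -- from train_a
print(combined[2])  # [6, 7]    -- from train_b
```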
@@ -84,18 +116,26 @@ class CompletionsDataset:
 def create_dataset(
     data,
     tokenizer: PreTrainedTokenizer,
-    prompt_feature: Optional[str] = None,
-    completion_feature: Optional[str] = None,
+    config,
 ):
-    prompt_feature = prompt_feature or "prompt"
-    completion_feature = completion_feature or "completion"
+    mask_prompt = getattr(config, "mask_prompt", False)
+    prompt_feature = getattr(config, "prompt_feature", "prompt")
+    text_feature = getattr(config, "text_feature", "text")
+    completion_feature = getattr(config, "completion_feature", "completion")
+    chat_feature = getattr(config, "chat_feature", "messages")
     sample = data[0]
-    if "messages" in sample:
-        return ChatDataset(data, tokenizer)
-    elif prompt_feature in sample and completion_feature in sample:
-        return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature)
-    elif "text" in sample:
-        return Dataset(data, tokenizer)
+    if prompt_feature in sample and completion_feature in sample:
+        return CompletionsDataset(
+            data, tokenizer, prompt_feature, completion_feature, mask_prompt
+        )
+    elif chat_feature in sample:
+        return ChatDataset(
+            data, tokenizer, chat_key=chat_feature, mask_prompt=mask_prompt
+        )
+    elif text_feature in sample:
+        if mask_prompt:
+            raise ValueError("Prompt masking not supported for text dataset.")
+        return Dataset(data, tokenizer, text_key=text_feature)
     else:
         raise ValueError(
             "Unsupported data format, check the supported formats here:\n"
@@ -106,15 +146,14 @@ def create_dataset(
 def load_local_dataset(
     data_path: Path,
     tokenizer: PreTrainedTokenizer,
-    prompt_feature: Optional[str] = None,
-    completion_feature: Optional[str] = None,
+    config,
 ):
     def load_subset(path):
         if not path.exists():
             return []
         with open(path, "r") as fid:
             data = [json.loads(l) for l in fid]
-        return create_dataset(data, tokenizer, prompt_feature, completion_feature)
+        return create_dataset(data, tokenizer, config)
 
     names = ("train", "valid", "test")
     train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names]
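
`load_local_dataset` expects `train.jsonl`, `valid.jsonl`, and `test.jsonl` under the data directory, one JSON object per line; a missing file simply produces an empty split. A small sketch of writing such files (the directory name is a placeholder):

```python
import json
from pathlib import Path

data_path = Path("my_data")  # hypothetical directory passed as --data
data_path.mkdir(exist_ok=True)
rows = [{"prompt": "What is 2+2?", "completion": "4"}]
for split in ("train", "valid"):
    with open(data_path / f"{split}.jsonl", "w") as fid:
        for r in rows:
            fid.write(json.dumps(r) + "\n")
```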
@@ -124,8 +163,7 @@ def load_local_dataset(
 def load_hf_dataset(
     data_id: str,
     tokenizer: PreTrainedTokenizer,
-    prompt_feature: Optional[str] = None,
-    completion_feature: Optional[str] = None,
+    config,
 ):
     from datasets import exceptions, load_dataset
@@ -136,9 +174,7 @@ def load_hf_dataset(
         train, valid, test = [
             (
-                create_dataset(
-                    dataset[n], tokenizer, prompt_feature, completion_feature
-                )
+                create_dataset(dataset[n], tokenizer, config)
                 if n in dataset.keys()
                 else []
             )
@@ -154,42 +190,61 @@ def load_hf_dataset(
 def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
     import datasets
 
-    hf_args = args.hf_dataset
-    dataset_name = hf_args["name"]
-    print(f"Loading Hugging Face dataset {dataset_name}.")
-    text_feature = hf_args.get("text_feature")
-    prompt_feature = hf_args.get("prompt_feature")
-    completion_feature = hf_args.get("completion_feature")
-
-    def create_hf_dataset(split: str = None):
+    def create_hf_dataset(dataset_name, config, split, hf_config):
         ds = datasets.load_dataset(
             dataset_name,
             split=split,
-            **hf_args.get("config", {}),
+            **hf_config,
         )
-        if prompt_feature and completion_feature:
-            return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
-        elif text_feature:
-            return Dataset(ds, tokenizer, text_key=text_feature)
-        else:
-            raise ValueError(
-                "Specify either a prompt and completion feature or a text "
-                "feature for the Hugging Face dataset."
-            )
+        return create_dataset(ds, tokenizer, config)
+
+    dataset_collection = args.hf_dataset
+    if isinstance(dataset_collection, dict):
+        dataset_collection = [dataset_collection]
+
+    collection = []
+    for ds in dataset_collection:
+        ds_name = ds["name"]
+        print(f"Loading Hugging Face dataset {ds_name}.")
+        ds["mask_prompt"] = getattr(args, "mask_prompt", False)
+        config = types.SimpleNamespace(**ds)
+        hf_config = ds.get("config", {})
+        if args.train:
+            train_split = ds.get("train_split", "train[:80%]")
+            valid_split = ds.get("valid_split", "train[-10%:]")
+            train = create_hf_dataset(
+                ds_name,
+                config,
+                train_split,
+                hf_config,
+            )
+            valid = create_hf_dataset(
+                ds_name,
+                config,
+                valid_split,
+                hf_config,
+            )
+        else:
+            train, valid = [], []
-    if args.train:
-        train_split = hf_args.get("train_split", "train[:80%]")
-        valid_split = hf_args.get("valid_split", "train[-10%:]")
-        train = create_hf_dataset(split=train_split)
-        valid = create_hf_dataset(split=valid_split)
-    else:
-        train, valid = [], []
-    if args.test:
-        test = create_hf_dataset(split=hf_args.get("test_split"))
-    else:
-        test = []
+        if args.test:
+            test_split = ds.get("test_split")
+            test = create_hf_dataset(
+                ds_name,
+                config,
+                test_split,
+                hf_config,
+            )
+        else:
+            test = []
-
-    return train, valid, test
+
+        collection.append((train, valid, test))
+
+    if len(collection) == 1:
+        return collection[0]
+
+    # Otherwise concatenate them
+    return tuple(map(ConcatenatedDataset, zip(*collection)))
 
 
 def load_dataset(args, tokenizer: PreTrainedTokenizer):
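
The `hf_dataset` entry can now be either a single table or a list; each entry carries its own feature names and splits, and its `config` dict is forwarded to `datasets.load_dataset`. A sketch of a two-dataset collection (names and splits are placeholders):

```python
hf_dataset = [
    {
        "name": "org/instruct-data",            # placeholder dataset id
        "prompt_feature": "instruction",
        "completion_feature": "response",
        "train_split": "train[:90%]",
        "valid_split": "train[-10%:]",
    },
    {
        "name": "org/chat-data",                # placeholder dataset id
        "chat_feature": "conversations",
        "config": {"trust_remote_code": True},  # forwarded to datasets.load_dataset
    },
]
# args.mask_prompt is copied into every entry, and when more than one entry is
# given the per-split results are merged with ConcatenatedDataset.
```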
@@ -197,18 +252,11 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer):
         train, valid, test = load_custom_hf_dataset(args, tokenizer)
     else:
         data_path = Path(args.data)
-        prompt_feature = getattr(args, "prompt_feature", None)
-        completion_feature = getattr(args, "completion_feature", None)
         if data_path.exists():
-            train, valid, test = load_local_dataset(
-                data_path, tokenizer, prompt_feature, completion_feature
-            )
+            train, valid, test = load_local_dataset(data_path, tokenizer, args)
         else:
             print(f"Loading Hugging Face dataset {args.data}.")
-            train, valid, test = load_hf_dataset(
-                args.data, tokenizer, prompt_feature, completion_feature
-            )
+            train, valid, test = load_hf_dataset(args.data, tokenizer, args)
 
     if args.train and len(train) == 0:
         raise ValueError(

mlx_lm/tuner/trainer.py

@@ -5,13 +5,16 @@ import shutil
 import time
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Union
+from typing import List, Optional, Tuple
 
 import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
 from mlx.nn.utils import average_gradients
 from mlx.utils import tree_flatten
+from transformers import PreTrainedTokenizer
+
+from .datasets import CompletionsDataset
 
 
 def grad_checkpoint(layer):
@@ -63,20 +66,30 @@ class TrainingArgs:
     )
 
 
-def default_loss(model, inputs, targets, lengths):
+def default_loss(model, batch, lengths):
+    inputs = batch[:, :-1]
+    targets = batch[:, 1:]
+
     logits = model(inputs)
     logits = logits.astype(mx.float32)
 
-    length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
+    steps = mx.arange(1, targets.shape[1] + 1)
+    mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:])
 
-    ce = nn.losses.cross_entropy(logits, targets) * length_mask
-    ntoks = length_mask.sum()
+    ce = nn.losses.cross_entropy(logits, targets) * mask
+    ntoks = mask.sum()
     ce = ce.sum() / ntoks
 
     return ce, ntoks
 
 
-def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
+def iterate_batches(
+    dataset,
+    tokenizer,
+    batch_size,
+    max_seq_length,
+    train=False,
+):
     # Sort by length:
     idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
     if len(dataset) < batch_size:
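
A small worked example (illustrative, assuming `mlx` is installed) of how `default_loss` turns the `(offset, length)` pairs from `iterate_batches` into a per-token loss mask:

```python
import mlx.core as mx

# One sequence of 6 tokens whose first 4 tokens are prompt: offset=4, length=6.
lengths = mx.array([[4, 6]])           # shape (batch, 2): [offset, length]
num_targets = 5                        # a 6-token batch yields 5 shifted targets
steps = mx.arange(1, num_targets + 1)  # [1, 2, 3, 4, 5]
mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:])
print(mask)  # [[False, False, False, True, True]] -> loss only on completion tokens
```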
@@ -101,6 +114,10 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
         indices = np.random.permutation(len(batch_idx))
         for i in indices:
             batch = [dataset[j] for j in batch_idx[i]]
+            if len(batch[0]) == 2:
+                batch, offsets = zip(*batch)
+            else:
+                offsets = [0] * len(batch)
             lengths = [len(x) for x in batch]
             if max(lengths) > max_seq_length:
                 print(
@@ -123,8 +140,7 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
                     truncated_length  # Update lengths to match truncated lengths
                 )
             batch = mx.array(batch_arr)
 
-            yield batch[:, :-1], batch[:, 1:], mx.array(lengths)
+            yield batch, mx.array(list(zip(offsets, lengths)))
 
         if not train:
             break