From 5865899c81d35ea48c6b69071d7fe61a46880d30 Mon Sep 17 00:00:00 2001
From: Chime Ogbuji
Date: Sun, 9 Feb 2025 23:12:34 -0500
Subject: [PATCH] Completion only fine-tuning of instruction models with collections of HF datasets (#1103)

- Optional completion only fine-tuning with `--mask-prompt`
- Collections of Hugging Face datasets

---------

Co-authored-by: Awni Hannun
---
 llms/mlx_lm/LORA.md            |  26 ++++-
 llms/mlx_lm/lora.py            |   9 ++
 llms/mlx_lm/tokenizer_utils.py |   6 ++
 llms/mlx_lm/tuner/datasets.py  | 186 +++++++++++++++++++++------------
 llms/mlx_lm/tuner/trainer.py   |  32 ++++--
 llms/tests/test_datsets.py     |  25 +++--
 6 files changed, 199 insertions(+), 85 deletions(-)

diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md
index 9eac9d7f..e863abc4 100644
--- a/llms/mlx_lm/LORA.md
+++ b/llms/mlx_lm/LORA.md
@@ -76,6 +76,14 @@ You can specify the output location with `--adapter-path`.
 You can resume fine-tuning with an existing adapter with
 `--resume-adapter-file `.
 
+#### Prompt Masking
+
+The default training computes a loss for every token in the sample. You can
+ignore the prompt and compute the loss for just the completion by passing
+`--mask-prompt`. Note this is only supported for `chat` and `completions`
+datasets. For `chat` datasets the final message in the message list is
+considered the completion. See the [dataset section](#Data) for more details.
+
 ### Evaluate
 
 To compute test set perplexity use:
@@ -290,11 +298,27 @@ hf_dataset:
 
 - Use `prompt_feature` and `completion_feature` to specify keys for a
   `completions` dataset. Use `text_feature` to specify the key for a `text`
-  dataset.
+  dataset. Use `chat_feature` to specify the key for a chat dataset.
 
 - To specify the train, valid, or test splits, set the corresponding
   `{train,valid,test}_split` argument.
 
+You can specify multiple Hugging Face datasets using a list of records, each
+with the same structure as above. For example:
+
+```yaml
+hf_dataset:
+  - name: "Open-Orca/OpenOrca"
+    train_split: "train[:90%]"
+    valid_split: "train[-10%:]"
+    prompt_feature: "question"
+    completion_feature: "response"
+  - name: "trl-lib/ultrafeedback_binarized"
+    train_split: "train[:90%]"
+    valid_split: "train[-10%:]"
+    chat_feature: "chosen"
+```
+
 - Arguments specified in `config` will be passed as keyword arguments to
   [`datasets.load_dataset`](https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset).
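An illustrative aside (editorial, not part of the patch): the prompt/completion boundary described above can be located directly with a tokenizer's chat template. A minimal sketch, assuming any chat-capable tokenizer (the model id below is a placeholder), mirroring the `ChatDataset` change later in this patch:

```python
# Sketch: find the prompt/completion boundary for one chat sample.
# The offset is the tokenized length of every message except the final
# (assistant) one; tokens before the offset are excluded from the loss.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("some-chat-model")  # placeholder model id

messages = [
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "Paris."},
]
tokens = tok.apply_chat_template(messages)
offset = len(tok.apply_chat_template(messages[:-1]))
# (tokens, offset) is what the dataset yields when --mask-prompt is set;
# only tokens[offset:] (the completion) contribute to the training loss.
```
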
diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py index 43f508c3..abc5dfa9 100644 --- a/llms/mlx_lm/lora.py +++ b/llms/mlx_lm/lora.py @@ -94,6 +94,14 @@ def build_parser(): choices=["lora", "dora", "full"], help="Type of fine-tuning to perform: lora, dora, or full.", ) + + parser.add_argument( + "--mask-prompt", + action="store_true", + help="Mask the prompt in the loss when training", + default=False, + ) + parser.add_argument( "--num-layers", type=int, @@ -219,6 +227,7 @@ def train_model( build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate ) ) + # Train model train( model=model, diff --git a/llms/mlx_lm/tokenizer_utils.py b/llms/mlx_lm/tokenizer_utils.py index 1b5bdd77..de9d5324 100644 --- a/llms/mlx_lm/tokenizer_utils.py +++ b/llms/mlx_lm/tokenizer_utils.py @@ -1,5 +1,6 @@ import json from functools import partial +from typing import List from transformers import AutoTokenizer @@ -368,3 +369,8 @@ def load_tokenizer(model_path, tokenizer_config_extra={}, eos_token_ids=None): detokenizer_class, eos_token_ids=eos_token_ids, ) + + +def no_bos_or_eos(sequence: List, bos: int, eos: int) -> List: + removed_bos = sequence if sequence[0] != bos else sequence[1:] + return removed_bos[:-1] if removed_bos[-1] == eos else removed_bos diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py index 377e7cae..a6f3bd29 100644 --- a/llms/mlx_lm/tuner/datasets.py +++ b/llms/mlx_lm/tuner/datasets.py @@ -1,6 +1,8 @@ +import itertools import json +import types from pathlib import Path -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional from transformers import PreTrainedTokenizer @@ -34,14 +36,24 @@ class ChatDataset: https://platform.openai.com/docs/guides/fine-tuning/example-format """ - def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer): - self._data = [ - tokenizer.apply_chat_template( - d["messages"], - tools=d.get("tools", None), - ) - for d in data - ] + def __init__( + self, + data: List[Dict[str, str]], + tokenizer: PreTrainedTokenizer, + chat_key: str = "messages", + mask_prompt: bool = False, + ): + self._data = [] + for d in data: + messages = d[chat_key] + tools = d.get("tools", None) + tokens = tokenizer.apply_chat_template(messages, tools=tools) + if mask_prompt: + messages = messages[:-1] + offset = len(tokenizer.apply_chat_template(messages, tools=tools)) + self._data.append((tokens, offset)) + else: + self._data.append(tokens) def __getitem__(self, idx: int): return self._data[idx] @@ -63,16 +75,36 @@ class CompletionsDataset: tokenizer: PreTrainedTokenizer, prompt_key: str, completion_key: str, + mask_prompt: bool, ): - self._data = [ - tokenizer.apply_chat_template( + self._data = [] + for d in data: + tokens = tokenizer.apply_chat_template( [ {"role": "user", "content": d[prompt_key]}, {"role": "assistant", "content": d[completion_key]}, ], ) - for d in data - ] + if mask_prompt: + offset = len( + tokenizer.apply_chat_template( + [{"role": "user", "content": d[prompt_key]}] + ) + ) + self._data.append((tokens, offset)) + else: + self._data.append(tokens) + + def __getitem__(self, idx: int): + return self._data[idx] + + def __len__(self): + return len(self._data) + + +class ConcatenatedDataset: + def __init__(self, data: List[Any]): + self._data = list(itertools.chain(*data)) def __getitem__(self, idx: int): return self._data[idx] @@ -84,18 +116,26 @@ class CompletionsDataset: def create_dataset( data, tokenizer: PreTrainedTokenizer, - prompt_feature: Optional[str] = None, - 
completion_feature: Optional[str] = None, + config, ): - prompt_feature = prompt_feature or "prompt" - completion_feature = completion_feature or "completion" + mask_prompt = getattr(config, "mask_prompt", False) + prompt_feature = getattr(config, "prompt_feature", "prompt") + text_feature = getattr(config, "text_feature", "text") + completion_feature = getattr(config, "completion_feature", "completion") + chat_feature = getattr(config, "chat_feature", "messages") sample = data[0] - if "messages" in sample: - return ChatDataset(data, tokenizer) - elif prompt_feature in sample and completion_feature in sample: - return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature) - elif "text" in sample: - return Dataset(data, tokenizer) + if prompt_feature in sample and completion_feature in sample: + return CompletionsDataset( + data, tokenizer, prompt_feature, completion_feature, mask_prompt + ) + elif chat_feature in sample: + return ChatDataset( + data, tokenizer, chat_key=chat_feature, mask_prompt=mask_prompt + ) + elif text_feature in sample: + if mask_prompt: + raise ValueError("Prompt masking not supported for text dataset.") + return Dataset(data, tokenizer, text_key=text_feature) else: raise ValueError( "Unsupported data format, check the supported formats here:\n" @@ -106,15 +146,14 @@ def create_dataset( def load_local_dataset( data_path: Path, tokenizer: PreTrainedTokenizer, - prompt_feature: Optional[str] = None, - completion_feature: Optional[str] = None, + config, ): def load_subset(path): if not path.exists(): return [] with open(path, "r") as fid: data = [json.loads(l) for l in fid] - return create_dataset(data, tokenizer, prompt_feature, completion_feature) + return create_dataset(data, tokenizer, config) names = ("train", "valid", "test") train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names] @@ -124,8 +163,7 @@ def load_local_dataset( def load_hf_dataset( data_id: str, tokenizer: PreTrainedTokenizer, - prompt_feature: Optional[str] = None, - completion_feature: Optional[str] = None, + config, ): from datasets import exceptions, load_dataset @@ -136,9 +174,7 @@ def load_hf_dataset( train, valid, test = [ ( - create_dataset( - dataset[n], tokenizer, prompt_feature, completion_feature - ) + create_dataset(dataset[n], tokenizer, config) if n in dataset.keys() else [] ) @@ -154,42 +190,61 @@ def load_hf_dataset( def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer): import datasets - hf_args = args.hf_dataset - dataset_name = hf_args["name"] - print(f"Loading Hugging Face dataset {dataset_name}.") - text_feature = hf_args.get("text_feature") - prompt_feature = hf_args.get("prompt_feature") - completion_feature = hf_args.get("completion_feature") - - def create_hf_dataset(split: str = None): + def create_hf_dataset(dataset_name, config, split, hf_config): ds = datasets.load_dataset( dataset_name, split=split, - **hf_args.get("config", {}), + **hf_config, ) - if prompt_feature and completion_feature: - return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature) - elif text_feature: - return Dataset(ds, tokenizer, text_key=text_feature) - else: - raise ValueError( - "Specify either a prompt and completion feature or a text " - "feature for the Hugging Face dataset." 
+ return create_dataset(ds, tokenizer, config) + + dataset_collection = args.hf_dataset + if isinstance(dataset_collection, dict): + dataset_collection = [dataset_collection] + + collection = [] + for ds in dataset_collection: + ds_name = ds["name"] + print(f"Loading Hugging Face dataset {ds_name}.") + ds["mask_prompt"] = getattr(args, "mask_prompt", False) + config = types.SimpleNamespace(**ds) + hf_config = ds.get("config", {}) + if args.train: + train_split = ds.get("train_split", "train[:80%]") + valid_split = ds.get("valid_split", "train[-10%:]") + train = create_hf_dataset( + ds_name, + config, + train_split, + hf_config, ) + valid = create_hf_dataset( + ds_name, + config, + valid_split, + hf_config, + ) + else: + train, valid = [], [] - if args.train: - train_split = hf_args.get("train_split", "train[:80%]") - valid_split = hf_args.get("valid_split", "train[-10%:]") - train = create_hf_dataset(split=train_split) - valid = create_hf_dataset(split=valid_split) - else: - train, valid = [], [] - if args.test: - test = create_hf_dataset(split=hf_args.get("test_split")) - else: - test = [] + if args.test: + test_split = ds.get("test_split") + test = create_hf_dataset( + ds_name, + config, + test_split, + hf_config, + ) + else: + test = [] - return train, valid, test + collection.append((train, valid, test)) + + if len(collection) == 1: + return collection[0] + + # Otherwise concatenate them + return tuple(map(ConcatenatedDataset, zip(*collection))) def load_dataset(args, tokenizer: PreTrainedTokenizer): @@ -197,18 +252,11 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer): train, valid, test = load_custom_hf_dataset(args, tokenizer) else: data_path = Path(args.data) - - prompt_feature = getattr(args, "prompt_feature", None) - completion_feature = getattr(args, "completion_feature", None) if data_path.exists(): - train, valid, test = load_local_dataset( - data_path, tokenizer, prompt_feature, completion_feature - ) + train, valid, test = load_local_dataset(data_path, tokenizer, args) else: print(f"Loading Hugging Face dataset {args.data}.") - train, valid, test = load_hf_dataset( - args.data, tokenizer, prompt_feature, completion_feature - ) + train, valid, test = load_hf_dataset(args.data, tokenizer, args) if args.train and len(train) == 0: raise ValueError( diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py index bf84d066..d675f9b6 100644 --- a/llms/mlx_lm/tuner/trainer.py +++ b/llms/mlx_lm/tuner/trainer.py @@ -5,13 +5,16 @@ import shutil import time from dataclasses import dataclass, field from pathlib import Path -from typing import Union +from typing import List, Optional, Tuple import mlx.core as mx import mlx.nn as nn import numpy as np from mlx.nn.utils import average_gradients from mlx.utils import tree_flatten +from transformers import PreTrainedTokenizer + +from .datasets import CompletionsDataset def grad_checkpoint(layer): @@ -63,20 +66,30 @@ class TrainingArgs: ) -def default_loss(model, inputs, targets, lengths): +def default_loss(model, batch, lengths): + inputs = batch[:, :-1] + targets = batch[:, 1:] + logits = model(inputs) logits = logits.astype(mx.float32) - length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None] + steps = mx.arange(1, targets.shape[1] + 1) + mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:]) - ce = nn.losses.cross_entropy(logits, targets) * length_mask - ntoks = length_mask.sum() + ce = nn.losses.cross_entropy(logits, targets) * mask + ntoks = mask.sum() ce = ce.sum() / ntoks return ce, 
ntoks -def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False): +def iterate_batches( + dataset, + tokenizer, + batch_size, + max_seq_length, + train=False, +): # Sort by length: idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx])) if len(dataset) < batch_size: @@ -101,6 +114,10 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False) indices = np.random.permutation(len(batch_idx)) for i in indices: batch = [dataset[j] for j in batch_idx[i]] + if len(batch[0]) == 2: + batch, offsets = zip(*batch) + else: + offsets = [0] * len(batch) lengths = [len(x) for x in batch] if max(lengths) > max_seq_length: print( @@ -123,8 +140,7 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False) truncated_length # Update lengths to match truncated lengths ) batch = mx.array(batch_arr) - - yield batch[:, :-1], batch[:, 1:], mx.array(lengths) + yield batch, mx.array(list(zip(offsets, lengths))) if not train: break diff --git a/llms/tests/test_datsets.py b/llms/tests/test_datsets.py index dd86d277..5edab8bf 100644 --- a/llms/tests/test_datsets.py +++ b/llms/tests/test_datsets.py @@ -78,14 +78,15 @@ class TestDatasets(unittest.TestCase): self.assertTrue(isinstance(train, datasets.ChatDataset)) def test_hf(self): + hf_args = { + "name": "billsum", + "prompt_feature": "text", + "completion_feature": "summary", + "train_split": "train[:2%]", + "valid_split": "train[-2%:]", + } args = types.SimpleNamespace( - hf_dataset={ - "name": "billsum", - "prompt_feature": "text", - "completion_feature": "summary", - "train_split": "train[:2%]", - "valid_split": "train[-2%:]", - }, + hf_dataset=hf_args, test=False, train=True, ) @@ -97,6 +98,16 @@ class TestDatasets(unittest.TestCase): self.assertTrue(len(valid[0]) > 0) self.assertEqual(len(test), 0) + args = types.SimpleNamespace( + hf_dataset=[hf_args, hf_args], + test=False, + train=True, + ) + train_double, valid_double, test_double = datasets.load_dataset(args, tokenizer) + self.assertEqual(2 * len(train), len(train_double)) + self.assertEqual(2 * len(valid), len(valid_double)) + self.assertEqual(2 * len(test), len(test_double)) + if __name__ == "__main__": unittest.main()
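
An illustrative aside (editorial, not part of the patch): a minimal sketch of the loss mask that the updated `default_loss` builds from the `(offset, length)` pairs yielded by `iterate_batches`, with made-up numbers for one sample:

```python
# Sketch: loss mask for a single 6-token sample whose first 3 tokens are the
# prompt. Mirrors default_loss above: targets are batch[:, 1:], so steps run
# 1..5, and only steps in [offset, length] contribute to the cross-entropy.
import mlx.core as mx

lengths = mx.array([[3, 6]])   # (offset, total length) for one sample
steps = mx.arange(1, 6)        # targets.shape[1] == 5 for a 6-token row
mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:])
print(mask)  # [[False, False, True, True, True]]
```

Steps 1 and 2 correspond to prompt targets and are excluded; steps 3 through 5 (the completion) are the only positions counted in the loss.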