mirror of
				https://github.com/ml-explore/mlx-examples.git
				synced 2025-10-23 05:58:07 +08:00 
			
		
		
		
	Completion only fine-tuning of instruction models with collections of HF datasets (#1103)
- Optional completion only fine-tuning with `--mask-prompt` - Collections of Hugging Face datasets --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
		| @@ -76,6 +76,14 @@ You can specify the output location with `--adapter-path`. | ||||
| You can resume fine-tuning with an existing adapter with | ||||
| `--resume-adapter-file <path_to_adapters.safetensors>`. | ||||
|  | ||||
| #### Prompt Masking | ||||
|  | ||||
| The default training computes a loss for every token in the sample. You can | ||||
| ignore the prompt and compute loss for just the completion by passing | ||||
| `--mask-prompt`. Note this is only supported for `chat` and `completion` | ||||
| datasets. For `chat` datasets the final message in the message list is | ||||
| considered the completion. See the [dataset section](#Data) for more details.  | ||||
|  | ||||
| ### Evaluate | ||||
|  | ||||
| To compute test set perplexity use: | ||||
| @@ -290,11 +298,27 @@ hf_dataset: | ||||
|  | ||||
| - Use `prompt_feature` and `completion_feature` to specify keys for a | ||||
|   `completions` dataset. Use `text_feature` to specify the key for a `text` | ||||
|   dataset.  | ||||
|   dataset. Use `chat_feature` to specify the key for a chat dataset. | ||||
|  | ||||
| - To specify the train, valid, or test splits, set the corresponding | ||||
|   `{train,valid,test}_split` argument.  | ||||
|  | ||||
| You can specify a list of Hugging Face datasets with a list of records each | ||||
| with the same structure as above. For example: | ||||
|  | ||||
| ```yaml | ||||
| hf_dataset: | ||||
|   - name: "Open-Orca/OpenOrca" | ||||
|     train_split: "train[:90%]" | ||||
|     valid_split: "train[-10%:]" | ||||
|     prompt_feature: "question" | ||||
|     completion_feature: "response" | ||||
|   - name: "trl-lib/ultrafeedback_binarized" | ||||
|     train_split: "train[:90%]" | ||||
|     valid_split: "train[-10%:]" | ||||
|     chat_feature: "chosen" | ||||
| ``` | ||||
|  | ||||
| - Arguments specified in `config` will be passed as keyword arguments to | ||||
|   [`datasets.load_dataset`](https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset). | ||||
|  | ||||
|   | ||||
| @@ -94,6 +94,14 @@ def build_parser(): | ||||
|         choices=["lora", "dora", "full"], | ||||
|         help="Type of fine-tuning to perform: lora, dora, or full.", | ||||
|     ) | ||||
|  | ||||
|     parser.add_argument( | ||||
|         "--mask-prompt", | ||||
|         action="store_true", | ||||
|         help="Mask the prompt in the loss when training", | ||||
|         default=False, | ||||
|     ) | ||||
|  | ||||
|     parser.add_argument( | ||||
|         "--num-layers", | ||||
|         type=int, | ||||
| @@ -219,6 +227,7 @@ def train_model( | ||||
|             build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate | ||||
|         ) | ||||
|     ) | ||||
|  | ||||
|     # Train model | ||||
|     train( | ||||
|         model=model, | ||||
|   | ||||
| @@ -1,5 +1,6 @@ | ||||
| import json | ||||
| from functools import partial | ||||
| from typing import List | ||||
|  | ||||
| from transformers import AutoTokenizer | ||||
|  | ||||
| @@ -368,3 +369,8 @@ def load_tokenizer(model_path, tokenizer_config_extra={}, eos_token_ids=None): | ||||
|         detokenizer_class, | ||||
|         eos_token_ids=eos_token_ids, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def no_bos_or_eos(sequence: List, bos: int, eos: int) -> List: | ||||
|     removed_bos = sequence if sequence[0] != bos else sequence[1:] | ||||
|     return removed_bos[:-1] if removed_bos[-1] == eos else removed_bos | ||||
|   | ||||
| @@ -1,6 +1,8 @@ | ||||
| import itertools | ||||
| import json | ||||
| import types | ||||
| from pathlib import Path | ||||
| from typing import Dict, List, Optional | ||||
| from typing import Any, Dict, List, Optional | ||||
|  | ||||
| from transformers import PreTrainedTokenizer | ||||
|  | ||||
| @@ -34,14 +36,24 @@ class ChatDataset: | ||||
|     https://platform.openai.com/docs/guides/fine-tuning/example-format | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer): | ||||
|         self._data = [ | ||||
|             tokenizer.apply_chat_template( | ||||
|                 d["messages"], | ||||
|                 tools=d.get("tools", None), | ||||
|             ) | ||||
|             for d in data | ||||
|         ] | ||||
|     def __init__( | ||||
|         self, | ||||
|         data: List[Dict[str, str]], | ||||
|         tokenizer: PreTrainedTokenizer, | ||||
|         chat_key: str = "messages", | ||||
|         mask_prompt: bool = False, | ||||
|     ): | ||||
|         self._data = [] | ||||
|         for d in data: | ||||
|             messages = d[chat_key] | ||||
|             tools = d.get("tools", None) | ||||
|             tokens = tokenizer.apply_chat_template(messages, tools=tools) | ||||
|             if mask_prompt: | ||||
|                 messages = messages[:-1] | ||||
|                 offset = len(tokenizer.apply_chat_template(messages, tools=tools)) | ||||
|                 self._data.append((tokens, offset)) | ||||
|             else: | ||||
|                 self._data.append(tokens) | ||||
|  | ||||
|     def __getitem__(self, idx: int): | ||||
|         return self._data[idx] | ||||
| @@ -63,16 +75,36 @@ class CompletionsDataset: | ||||
|         tokenizer: PreTrainedTokenizer, | ||||
|         prompt_key: str, | ||||
|         completion_key: str, | ||||
|         mask_prompt: bool, | ||||
|     ): | ||||
|         self._data = [ | ||||
|             tokenizer.apply_chat_template( | ||||
|         self._data = [] | ||||
|         for d in data: | ||||
|             tokens = tokenizer.apply_chat_template( | ||||
|                 [ | ||||
|                     {"role": "user", "content": d[prompt_key]}, | ||||
|                     {"role": "assistant", "content": d[completion_key]}, | ||||
|                 ], | ||||
|             ) | ||||
|             for d in data | ||||
|         ] | ||||
|             if mask_prompt: | ||||
|                 offset = len( | ||||
|                     tokenizer.apply_chat_template( | ||||
|                         [{"role": "user", "content": d[prompt_key]}] | ||||
|                     ) | ||||
|                 ) | ||||
|                 self._data.append((tokens, offset)) | ||||
|             else: | ||||
|                 self._data.append(tokens) | ||||
|  | ||||
|     def __getitem__(self, idx: int): | ||||
|         return self._data[idx] | ||||
|  | ||||
|     def __len__(self): | ||||
|         return len(self._data) | ||||
|  | ||||
|  | ||||
| class ConcatenatedDataset: | ||||
|     def __init__(self, data: List[Any]): | ||||
|         self._data = list(itertools.chain(*data)) | ||||
|  | ||||
|     def __getitem__(self, idx: int): | ||||
|         return self._data[idx] | ||||
| @@ -84,18 +116,26 @@ class CompletionsDataset: | ||||
| def create_dataset( | ||||
|     data, | ||||
|     tokenizer: PreTrainedTokenizer, | ||||
|     prompt_feature: Optional[str] = None, | ||||
|     completion_feature: Optional[str] = None, | ||||
|     config, | ||||
| ): | ||||
|     prompt_feature = prompt_feature or "prompt" | ||||
|     completion_feature = completion_feature or "completion" | ||||
|     mask_prompt = getattr(config, "mask_prompt", False) | ||||
|     prompt_feature = getattr(config, "prompt_feature", "prompt") | ||||
|     text_feature = getattr(config, "text_feature", "text") | ||||
|     completion_feature = getattr(config, "completion_feature", "completion") | ||||
|     chat_feature = getattr(config, "chat_feature", "messages") | ||||
|     sample = data[0] | ||||
|     if "messages" in sample: | ||||
|         return ChatDataset(data, tokenizer) | ||||
|     elif prompt_feature in sample and completion_feature in sample: | ||||
|         return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature) | ||||
|     elif "text" in sample: | ||||
|         return Dataset(data, tokenizer) | ||||
|     if prompt_feature in sample and completion_feature in sample: | ||||
|         return CompletionsDataset( | ||||
|             data, tokenizer, prompt_feature, completion_feature, mask_prompt | ||||
|         ) | ||||
|     elif chat_feature in sample: | ||||
|         return ChatDataset( | ||||
|             data, tokenizer, chat_key=chat_feature, mask_prompt=mask_prompt | ||||
|         ) | ||||
|     elif text_feature in sample: | ||||
|         if mask_prompt: | ||||
|             raise ValueError("Prompt masking not supported for text dataset.") | ||||
|         return Dataset(data, tokenizer, text_key=text_feature) | ||||
|     else: | ||||
|         raise ValueError( | ||||
|             "Unsupported data format, check the supported formats here:\n" | ||||
| @@ -106,15 +146,14 @@ def create_dataset( | ||||
| def load_local_dataset( | ||||
|     data_path: Path, | ||||
|     tokenizer: PreTrainedTokenizer, | ||||
|     prompt_feature: Optional[str] = None, | ||||
|     completion_feature: Optional[str] = None, | ||||
|     config, | ||||
| ): | ||||
|     def load_subset(path): | ||||
|         if not path.exists(): | ||||
|             return [] | ||||
|         with open(path, "r") as fid: | ||||
|             data = [json.loads(l) for l in fid] | ||||
|         return create_dataset(data, tokenizer, prompt_feature, completion_feature) | ||||
|         return create_dataset(data, tokenizer, config) | ||||
|  | ||||
|     names = ("train", "valid", "test") | ||||
|     train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names] | ||||
| @@ -124,8 +163,7 @@ def load_local_dataset( | ||||
| def load_hf_dataset( | ||||
|     data_id: str, | ||||
|     tokenizer: PreTrainedTokenizer, | ||||
|     prompt_feature: Optional[str] = None, | ||||
|     completion_feature: Optional[str] = None, | ||||
|     config, | ||||
| ): | ||||
|     from datasets import exceptions, load_dataset | ||||
|  | ||||
| @@ -136,9 +174,7 @@ def load_hf_dataset( | ||||
|  | ||||
|         train, valid, test = [ | ||||
|             ( | ||||
|                 create_dataset( | ||||
|                     dataset[n], tokenizer, prompt_feature, completion_feature | ||||
|                 ) | ||||
|                 create_dataset(dataset[n], tokenizer, config) | ||||
|                 if n in dataset.keys() | ||||
|                 else [] | ||||
|             ) | ||||
| @@ -154,42 +190,61 @@ def load_hf_dataset( | ||||
| def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer): | ||||
|     import datasets | ||||
|  | ||||
|     hf_args = args.hf_dataset | ||||
|     dataset_name = hf_args["name"] | ||||
|     print(f"Loading Hugging Face dataset {dataset_name}.") | ||||
|     text_feature = hf_args.get("text_feature") | ||||
|     prompt_feature = hf_args.get("prompt_feature") | ||||
|     completion_feature = hf_args.get("completion_feature") | ||||
|  | ||||
|     def create_hf_dataset(split: str = None): | ||||
|     def create_hf_dataset(dataset_name, config, split, hf_config): | ||||
|         ds = datasets.load_dataset( | ||||
|             dataset_name, | ||||
|             split=split, | ||||
|             **hf_args.get("config", {}), | ||||
|         ) | ||||
|         if prompt_feature and completion_feature: | ||||
|             return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature) | ||||
|         elif text_feature: | ||||
|             return Dataset(ds, tokenizer, text_key=text_feature) | ||||
|         else: | ||||
|             raise ValueError( | ||||
|                 "Specify either a prompt and completion feature or a text " | ||||
|                 "feature for the Hugging Face dataset." | ||||
|             **hf_config, | ||||
|         ) | ||||
|         return create_dataset(ds, tokenizer, config) | ||||
|  | ||||
|     dataset_collection = args.hf_dataset | ||||
|     if isinstance(dataset_collection, dict): | ||||
|         dataset_collection = [dataset_collection] | ||||
|  | ||||
|     collection = [] | ||||
|     for ds in dataset_collection: | ||||
|         ds_name = ds["name"] | ||||
|         print(f"Loading Hugging Face dataset {ds_name}.") | ||||
|         ds["mask_prompt"] = getattr(args, "mask_prompt", False) | ||||
|         config = types.SimpleNamespace(**ds) | ||||
|         hf_config = ds.get("config", {}) | ||||
|         if args.train: | ||||
|         train_split = hf_args.get("train_split", "train[:80%]") | ||||
|         valid_split = hf_args.get("valid_split", "train[-10%:]") | ||||
|         train = create_hf_dataset(split=train_split) | ||||
|         valid = create_hf_dataset(split=valid_split) | ||||
|             train_split = ds.get("train_split", "train[:80%]") | ||||
|             valid_split = ds.get("valid_split", "train[-10%:]") | ||||
|             train = create_hf_dataset( | ||||
|                 ds_name, | ||||
|                 config, | ||||
|                 train_split, | ||||
|                 hf_config, | ||||
|             ) | ||||
|             valid = create_hf_dataset( | ||||
|                 ds_name, | ||||
|                 config, | ||||
|                 valid_split, | ||||
|                 hf_config, | ||||
|             ) | ||||
|         else: | ||||
|             train, valid = [], [] | ||||
|  | ||||
|         if args.test: | ||||
|         test = create_hf_dataset(split=hf_args.get("test_split")) | ||||
|             test_split = ds.get("test_split") | ||||
|             test = create_hf_dataset( | ||||
|                 ds_name, | ||||
|                 config, | ||||
|                 test_split, | ||||
|                 hf_config, | ||||
|             ) | ||||
|         else: | ||||
|             test = [] | ||||
|  | ||||
|     return train, valid, test | ||||
|         collection.append((train, valid, test)) | ||||
|  | ||||
|     if len(collection) == 1: | ||||
|         return collection[0] | ||||
|  | ||||
|     # Otherwise concatenate them | ||||
|     return tuple(map(ConcatenatedDataset, zip(*collection))) | ||||
|  | ||||
|  | ||||
| def load_dataset(args, tokenizer: PreTrainedTokenizer): | ||||
| @@ -197,18 +252,11 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer): | ||||
|         train, valid, test = load_custom_hf_dataset(args, tokenizer) | ||||
|     else: | ||||
|         data_path = Path(args.data) | ||||
|  | ||||
|         prompt_feature = getattr(args, "prompt_feature", None) | ||||
|         completion_feature = getattr(args, "completion_feature", None) | ||||
|         if data_path.exists(): | ||||
|             train, valid, test = load_local_dataset( | ||||
|                 data_path, tokenizer, prompt_feature, completion_feature | ||||
|             ) | ||||
|             train, valid, test = load_local_dataset(data_path, tokenizer, args) | ||||
|         else: | ||||
|             print(f"Loading Hugging Face dataset {args.data}.") | ||||
|             train, valid, test = load_hf_dataset( | ||||
|                 args.data, tokenizer, prompt_feature, completion_feature | ||||
|             ) | ||||
|             train, valid, test = load_hf_dataset(args.data, tokenizer, args) | ||||
|  | ||||
|     if args.train and len(train) == 0: | ||||
|         raise ValueError( | ||||
|   | ||||
| @@ -5,13 +5,16 @@ import shutil | ||||
| import time | ||||
| from dataclasses import dataclass, field | ||||
| from pathlib import Path | ||||
| from typing import Union | ||||
| from typing import List, Optional, Tuple | ||||
|  | ||||
| import mlx.core as mx | ||||
| import mlx.nn as nn | ||||
| import numpy as np | ||||
| from mlx.nn.utils import average_gradients | ||||
| from mlx.utils import tree_flatten | ||||
| from transformers import PreTrainedTokenizer | ||||
|  | ||||
| from .datasets import CompletionsDataset | ||||
|  | ||||
|  | ||||
| def grad_checkpoint(layer): | ||||
| @@ -63,20 +66,30 @@ class TrainingArgs: | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def default_loss(model, inputs, targets, lengths): | ||||
| def default_loss(model, batch, lengths): | ||||
|     inputs = batch[:, :-1] | ||||
|     targets = batch[:, 1:] | ||||
|  | ||||
|     logits = model(inputs) | ||||
|     logits = logits.astype(mx.float32) | ||||
|  | ||||
|     length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None] | ||||
|     steps = mx.arange(1, targets.shape[1] + 1) | ||||
|     mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:]) | ||||
|  | ||||
|     ce = nn.losses.cross_entropy(logits, targets) * length_mask | ||||
|     ntoks = length_mask.sum() | ||||
|     ce = nn.losses.cross_entropy(logits, targets) * mask | ||||
|     ntoks = mask.sum() | ||||
|     ce = ce.sum() / ntoks | ||||
|  | ||||
|     return ce, ntoks | ||||
|  | ||||
|  | ||||
| def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False): | ||||
| def iterate_batches( | ||||
|     dataset, | ||||
|     tokenizer, | ||||
|     batch_size, | ||||
|     max_seq_length, | ||||
|     train=False, | ||||
| ): | ||||
|     # Sort by length: | ||||
|     idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx])) | ||||
|     if len(dataset) < batch_size: | ||||
| @@ -101,6 +114,10 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False) | ||||
|         indices = np.random.permutation(len(batch_idx)) | ||||
|         for i in indices: | ||||
|             batch = [dataset[j] for j in batch_idx[i]] | ||||
|             if len(batch[0]) == 2: | ||||
|                 batch, offsets = zip(*batch) | ||||
|             else: | ||||
|                 offsets = [0] * len(batch) | ||||
|             lengths = [len(x) for x in batch] | ||||
|             if max(lengths) > max_seq_length: | ||||
|                 print( | ||||
| @@ -123,8 +140,7 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False) | ||||
|                     truncated_length  # Update lengths to match truncated lengths | ||||
|                 ) | ||||
|             batch = mx.array(batch_arr) | ||||
|  | ||||
|             yield batch[:, :-1], batch[:, 1:], mx.array(lengths) | ||||
|             yield batch, mx.array(list(zip(offsets, lengths))) | ||||
|  | ||||
|         if not train: | ||||
|             break | ||||
|   | ||||
| @@ -78,14 +78,15 @@ class TestDatasets(unittest.TestCase): | ||||
|         self.assertTrue(isinstance(train, datasets.ChatDataset)) | ||||
|  | ||||
|     def test_hf(self): | ||||
|         args = types.SimpleNamespace( | ||||
|             hf_dataset={ | ||||
|         hf_args = { | ||||
|             "name": "billsum", | ||||
|             "prompt_feature": "text", | ||||
|             "completion_feature": "summary", | ||||
|             "train_split": "train[:2%]", | ||||
|             "valid_split": "train[-2%:]", | ||||
|             }, | ||||
|         } | ||||
|         args = types.SimpleNamespace( | ||||
|             hf_dataset=hf_args, | ||||
|             test=False, | ||||
|             train=True, | ||||
|         ) | ||||
| @@ -97,6 +98,16 @@ class TestDatasets(unittest.TestCase): | ||||
|         self.assertTrue(len(valid[0]) > 0) | ||||
|         self.assertEqual(len(test), 0) | ||||
|  | ||||
|         args = types.SimpleNamespace( | ||||
|             hf_dataset=[hf_args, hf_args], | ||||
|             test=False, | ||||
|             train=True, | ||||
|         ) | ||||
|         train_double, valid_double, test_double = datasets.load_dataset(args, tokenizer) | ||||
|         self.assertEqual(2 * len(train), len(train_double)) | ||||
|         self.assertEqual(2 * len(valid), len(valid_double)) | ||||
|         self.assertEqual(2 * len(test), len(test_double)) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Chime Ogbuji
					Chime Ogbuji