diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md
index b20243ef..0f0baf52 100644
--- a/llms/mlx_lm/LORA.md
+++ b/llms/mlx_lm/LORA.md
@@ -136,14 +136,54 @@ correct format.
 
 For fine-tuning (`--train`), the data loader expects a `train.jsonl` and a
 `valid.jsonl` to be in the data directory. For evaluation (`--test`), the data
-loader expects a `test.jsonl` in the data directory. Each line in the `*.jsonl`
-file should look like:
+loader expects a `test.jsonl` in the data directory.
+
+Currently, `*.jsonl` files support three data formats: `chat`,
+`completions`, and `text`. Here is an example of each format:
+
+`chat`:
+
+```jsonl
+{"messages": [
+  {"role": "system", "content": "You are a helpful assistant."},
+  {"role": "user", "content": "Hello."},
+  {"role": "assistant", "content": "How can I assist you today."}
+]}
 ```
+
+`completions`:
+
+```jsonl
+{"prompt": "What is the capital of France?", "completion": "Paris."}
+```
+
+`text`:
+
+```jsonl
 {"text": "This is an example for the model."}
 ```
 
-Note, other keys will be ignored by the loader.
+Note, the format is determined automatically from the data. Also note, any
+keys in a line not expected by the loader will be ignored.
+
+For the `chat` and `completions` formats, Hugging Face [chat
+templates](https://huggingface.co/blog/chat-templates) are used. The model's
+own chat template is applied by default. If the model does not have a chat
+template, Hugging Face falls back to a default one. For example, the final
+text in the `chat` example above with Hugging Face's default template becomes:
+
+```text
+<|im_start|>system
+You are a helpful assistant.<|im_end|>
+<|im_start|>user
+Hello.<|im_end|>
+<|im_start|>assistant
+How can I assist you today.<|im_end|>
+```
+
+If you are unsure which format to use, `chat` or `completions` are good places
+to start. For custom requirements on the format of the dataset, use the
+`text` format to assemble the content yourself.
 
 ## Memory Issues
diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index 615fb417..a31e973f 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -12,6 +12,7 @@ import numpy as np
 import yaml
 from mlx.utils import tree_flatten
 
+from .tuner.datasets import load_dataset
 from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
 from .tuner.utils import linear_to_lora_layers
 from .utils import load
@@ -141,46 +142,6 @@ def build_parser():
     return parser
 
 
-class Dataset:
-    """
-    Light-weight wrapper to hold lines from a jsonl file
-    """
-
-    def __init__(self, path: Path, key: str = "text"):
-        if not path.exists():
-            self._data = None
-        else:
-            with open(path, "r") as fid:
-                self._data = [json.loads(l) for l in fid]
-        self._key = key
-
-    def __getitem__(self, idx: int):
-        return self._data[idx][self._key]
-
-    def __len__(self):
-        if self._data is None:
-            return 0
-        return len(self._data)
-
-
-def load_dataset(args):
-    names = ("train", "valid", "test")
-    train, valid, test = (Dataset(Path(args.data) / f"{n}.jsonl") for n in names)
-    if args.train and len(train) == 0:
-        raise ValueError(
-            "Training set not found or empty. Must provide training set for fine-tuning."
-        )
-    if args.train and len(valid) == 0:
-        raise ValueError(
-            "Validation set not found or empty. Must provide validation set for fine-tuning."
-        )
-    if args.test and len(test) == 0:
-        raise ValueError(
-            "Test set not found or empty. Must provide test set for evaluation."
-        )
-    return train, valid, test
-
-
 def print_trainable_parameters(model):
     total_p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
     trainable_p = (
@@ -206,7 +167,7 @@ def run(args, training_callback: TrainingCallback = None):
     print_trainable_parameters(model)
 
     print("Loading datasets")
-    train_set, valid_set, test_set = load_dataset(args)
+    train_set, valid_set, test_set = load_dataset(args, tokenizer)
 
     # Resume training the given adapters.
     if args.resume_adapter_file is not None:
diff --git a/llms/mlx_lm/requirements.txt b/llms/mlx_lm/requirements.txt
index 518871ef..040fa864 100644
--- a/llms/mlx_lm/requirements.txt
+++ b/llms/mlx_lm/requirements.txt
@@ -3,3 +3,4 @@ numpy
 transformers>=4.38.0
 protobuf
 pyyaml
+jinja2
diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py
new file mode 100644
index 00000000..e5776160
--- /dev/null
+++ b/llms/mlx_lm/tuner/datasets.py
@@ -0,0 +1,104 @@
+import json
+from pathlib import Path
+
+from transformers import PreTrainedTokenizer
+
+
+class Dataset:
+    """
+    Light-weight wrapper to hold lines from a jsonl file
+    """
+
+    def __init__(self, path: Path):
+        with open(path, "r") as fid:
+            self._data = [json.loads(l) for l in fid]
+
+    def __getitem__(self, idx: int):
+        return self._data[idx]["text"]
+
+    def __len__(self):
+        if self._data is None:
+            return 0
+        return len(self._data)
+
+
+class ChatDataset(Dataset):
+    """
+    A dataset for chat data in the format of {"messages": [...]}
+    https://platform.openai.com/docs/guides/fine-tuning/example-format
+    """
+
+    def __init__(self, path: Path, tokenizer: PreTrainedTokenizer):
+        super().__init__(path)
+        self._tokenizer = tokenizer
+
+    def __getitem__(self, idx: int):
+        messages = self._data[idx]["messages"]
+        text = self._tokenizer.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
+        return text
+
+
+class CompletionsDataset(Dataset):
+    """
+    A dataset for prompt-completion data in the format of {"prompt": ..., "completion": ...}
+    https://platform.openai.com/docs/guides/fine-tuning/example-format
+    """
+
+    def __init__(self, path: Path, tokenizer: PreTrainedTokenizer):
+        super().__init__(path)
+        self._tokenizer = tokenizer
+
+    def __getitem__(self, idx: int):
+        data = self._data[idx]
+        text = self._tokenizer.apply_chat_template(
+            [
+                {"role": "user", "content": data["prompt"]},
+                {"role": "assistant", "content": data["completion"]},
+            ],
+            tokenize=False,
+            add_generation_prompt=True,
+        )
+        return text
+
+
+def create_dataset(path: Path, tokenizer: PreTrainedTokenizer = None):
+    # Return empty dataset for non-existent paths
+    if not path.exists():
+        return []
+    with open(path, "r") as fid:
+        first_line = next(fid)
+        first_obj = json.loads(first_line)
+        if "messages" in first_obj:
+            return ChatDataset(path, tokenizer)
+        elif "prompt" in first_obj and "completion" in first_obj:
+            return CompletionsDataset(path, tokenizer)
+        elif "text" in first_obj:
+            return Dataset(path)
+        else:
+            raise ValueError(
+                "Unsupported data format, check the supported formats here:\n"
+                "https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md#data."
+            )
+
+
+def load_dataset(args, tokenizer: PreTrainedTokenizer):
+    names = ("train", "valid", "test")
+    data_path = Path(args.data)
+    train, valid, test = [
+        create_dataset(data_path / f"{n}.jsonl", tokenizer) for n in names
+    ]
+    if args.train and len(train) == 0:
+        raise ValueError(
+            "Training set not found or empty. Must provide training set for fine-tuning."
+        )
+    if args.train and len(valid) == 0:
+        raise ValueError(
+            "Validation set not found or empty. Must provide validation set for fine-tuning."
+        )
+    if args.test and len(test) == 0:
+        raise ValueError(
+            "Test set not found or empty. Must provide test set for evaluation."
+        )
+    return train, valid, test
diff --git a/llms/tests/test_datsets.py b/llms/tests/test_datsets.py
new file mode 100644
index 00000000..8d8c01a5
--- /dev/null
+++ b/llms/tests/test_datsets.py
@@ -0,0 +1,81 @@
+# Copyright © 2024 Apple Inc.
+
+import json
+import os
+import tempfile
+import types
+import unittest
+
+from mlx_lm.tuner import datasets
+from transformers import AutoTokenizer
+
+HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
+
+
+class TestDatasets(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        cls.test_dir_fid = tempfile.TemporaryDirectory()
+        cls.test_dir = cls.test_dir_fid.name
+        if not os.path.isdir(cls.test_dir):
+            os.mkdir(cls.test_dir_fid.name)
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.test_dir_fid.cleanup()
+
+    def save_data(self, data):
+        for ds in ["train", "valid"]:
+            with open(os.path.join(self.test_dir, f"{ds}.jsonl"), "w") as fid:
+                for l in data:
+                    json.dump(l, fid)
+                    fid.write("\n")
+
+    def test_text(self):
+        data = {"text": "This is an example for the model."}
+        self.save_data(4 * [data])
+        args = types.SimpleNamespace(train=True, test=False, data=self.test_dir)
+        train, valid, test = datasets.load_dataset(args, None)
+        self.assertEqual(len(train), 4)
+        self.assertEqual(len(valid), 4)
+        self.assertEqual(len(test), 0)
+        self.assertTrue(len(train[0]) > 0)
+        self.assertTrue(len(valid[0]) > 0)
+        self.assertTrue(isinstance(train, datasets.Dataset))
+
+    def test_completions(self):
+        data = {"prompt": "What is the capital of France?", "completion": "Paris."}
+        self.save_data(4 * [data])
+        args = types.SimpleNamespace(train=True, test=False, data=self.test_dir)
+        tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH)
+        train, valid, test = datasets.load_dataset(args, tokenizer)
+        self.assertEqual(len(train), 4)
+        self.assertEqual(len(valid), 4)
+        self.assertEqual(len(test), 0)
+        self.assertTrue(len(train[0]) > 0)
+        self.assertTrue(len(valid[0]) > 0)
+        self.assertTrue(isinstance(train, datasets.CompletionsDataset))
+
+    def test_chat(self):
+        data = {
+            "messages": [
+                {"role": "system", "content": "You are a helpful assistant."},
+                {"role": "user", "content": "Hello."},
+                {"role": "assistant", "content": "How can I assist you today."},
+            ]
+        }
+        self.save_data(4 * [data])
+        args = types.SimpleNamespace(train=True, test=False, data=self.test_dir)
+        tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH)
+        train, valid, test = datasets.load_dataset(args, tokenizer)
+        self.assertEqual(len(train), 4)
+        self.assertEqual(len(valid), 4)
+        self.assertEqual(len(test), 0)
+        self.assertTrue(len(train[0]) > 0)
+        self.assertTrue(len(valid[0]) > 0)
+        self.assertTrue(isinstance(train, datasets.ChatDataset))
+
+
+if __name__ == "__main__":
+    unittest.main()
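For anyone who wants to try the new loader outside the test suite, here is a minimal end-to-end sketch. The `data/` directory name is an arbitrary choice for this example, and the `SimpleNamespace` stands in for the parsed CLI arguments, mirroring how the unit tests above drive `load_dataset`. The tokenizer checkpoint is the one the tests use; any Hugging Face model with a chat template should behave the same way.

```python
import json
import types
from pathlib import Path

from transformers import AutoTokenizer

from mlx_lm.tuner.datasets import load_dataset

# Arbitrary local directory for this sketch.
data_dir = Path("data")
data_dir.mkdir(exist_ok=True)

# Write a tiny `completions`-format dataset. The loader detects the format
# from the keys of the first line ("prompt" and "completion" here).
example = {"prompt": "What is the capital of France?", "completion": "Paris."}
for name in ("train", "valid"):
    with open(data_dir / f"{name}.jsonl", "w") as fid:
        for _ in range(4):
            json.dump(example, fid)
            fid.write("\n")

# Same checkpoint the unit tests use; it ships with a chat template.
tokenizer = AutoTokenizer.from_pretrained("mlx-community/Qwen1.5-0.5B-Chat-4bit")

# `load_dataset` only reads `train`, `test`, and `data` from its args object,
# so a SimpleNamespace can stand in for the parsed CLI arguments.
args = types.SimpleNamespace(train=True, test=False, data=str(data_dir))
train_set, valid_set, test_set = load_dataset(args, tokenizer)

print(len(train_set), len(valid_set), len(test_set))  # 4 4 0
print(train_set[0])  # the prompt/completion pair rendered by the chat template
```

One design note: `create_dataset` inspects only the first line of each file, so a file mixing formats will be read with whichever dataset class matches its first record. Keeping each `*.jsonl` file homogeneous avoids surprises.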