From 0228c46434157adaa48b44f9a227d2bb93354dc3 Mon Sep 17 00:00:00 2001 From: Chime Ogbuji Date: Mon, 13 Jan 2025 13:01:18 -0500 Subject: [PATCH] Custom local dataset features (#1085) * Generalize prompt_feature and completion_feature for use in local datasets to facilitate compatibility with many other training dataset formats. * Persist configured prompt/completion key * rebase + nits --------- Co-authored-by: Awni Hannun --- llms/mlx_lm/LORA.md | 17 +++++++++-- llms/mlx_lm/tuner/datasets.py | 55 ++++++++++++++++++++++++++--------- 2 files changed, 56 insertions(+), 16 deletions(-) diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md index 15676360..9eac9d7f 100644 --- a/llms/mlx_lm/LORA.md +++ b/llms/mlx_lm/LORA.md @@ -241,14 +241,25 @@ Refer to the documentation for the model you are fine-tuning for more details. {"prompt": "What is the capital of France?", "completion": "Paris."} ``` +For the `completions` data format, a different key can be used for the prompt +and completion by specifying the following in the YAML config: + +```yaml +prompt_feature: "input" +completion_feature: "output" +``` + +Here, `"input"` is the expected key instead of the default `"prompt"`, and +`"output"` is the expected key instead of `"completion"`. + `text`: ```jsonl {"text": "This is an example for the model."} ``` -Note, the format is automatically determined by the dataset. Note also, keys in -each line not expected by the loader will be ignored. +Note, the format is automatically determined by the dataset. Note also, keys +in each line not expected by the loader will be ignored. > [!NOTE] > Each example in the datasets must be on a single line. Do not put more than @@ -270,7 +281,7 @@ Otherwise, provide a mapping of keys in the dataset to the features MLX LM expects. Use a YAML config to specify the Hugging Face dataset arguments. For example: -``` +```yaml hf_dataset: name: "billsum" prompt_feature: "text" diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py index fa848f47..1b09c7e2 100644 --- a/llms/mlx_lm/tuner/datasets.py +++ b/llms/mlx_lm/tuner/datasets.py @@ -1,6 +1,6 @@ import json from pathlib import Path -from typing import Dict, List +from typing import Dict, List, Optional from transformers import PreTrainedTokenizer @@ -61,8 +61,8 @@ class CompletionsDataset: self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer, - prompt_key: str = "prompt", - completion_key: str = "completion", + prompt_key: str, + completion_key: str, ): self._data = [ tokenizer.apply_chat_template( @@ -81,13 +81,19 @@ class CompletionsDataset: return len(self._data) -def create_dataset(data, tokenizer: PreTrainedTokenizer): +def create_dataset( + data, + tokenizer: PreTrainedTokenizer, + prompt_feature: Optional[str] = None, + completion_feature: Optional[str] = None, +): + prompt_feature = prompt_feature or "prompt" + completion_feature = completion_feature or "completion" sample = data[0] - if "messages" in sample: return ChatDataset(data, tokenizer) - elif "prompt" in sample and "completion" in sample: - return CompletionsDataset(data, tokenizer) + elif prompt_feature in sample and completion_feature in sample: + return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature) elif "text" in sample: return Dataset(data, tokenizer) else: @@ -97,20 +103,30 @@ def create_dataset(data, tokenizer: PreTrainedTokenizer): ) -def load_local_dataset(data_path: Path, tokenizer: PreTrainedTokenizer): +def load_local_dataset( + data_path: Path, + tokenizer: PreTrainedTokenizer, + prompt_feature: Optional[str] = None, + completion_feature: Optional[str] = None, +): def load_subset(path): if not path.exists(): return [] with open(path, "r") as fid: data = [json.loads(l) for l in fid] - return create_dataset(data, tokenizer) + return create_dataset(data, tokenizer, prompt_feature, completion_feature) names = ("train", "valid", "test") train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names] return train, valid, test -def load_hf_dataset(data_id: str, tokenizer: PreTrainedTokenizer): +def load_hf_dataset( + data_id: str, + tokenizer: PreTrainedTokenizer, + prompt_feature: Optional[str] = None, + completion_feature: Optional[str] = None, +): from datasets import exceptions, load_dataset try: @@ -119,7 +135,13 @@ def load_hf_dataset(data_id: str, tokenizer: PreTrainedTokenizer): names = ("train", "valid", "test") train, valid, test = [ - create_dataset(dataset[n], tokenizer) if n in dataset.keys() else [] + ( + create_dataset( + dataset[n], tokenizer, prompt_feature, completion_feature + ) + if n in dataset.keys() + else [] + ) for n in names ] @@ -175,11 +197,18 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer): train, valid, test = load_custom_hf_dataset(args, tokenizer) else: data_path = Path(args.data) + + prompt_feature = getattr(args, "prompt_feature", None) + completion_feature = getattr(args, "completion_feature", None) if data_path.exists(): - train, valid, test = load_local_dataset(data_path, tokenizer) + train, valid, test = load_local_dataset( + data_path, tokenizer, prompt_feature, completion_feature + ) else: print(f"Loading Hugging Face dataset {args.data}.") - train, valid, test = load_hf_dataset(args.data, tokenizer) + train, valid, test = load_hf_dataset( + args.data, tokenizer, prompt_feature, completion_feature + ) if args.train and len(train) == 0: raise ValueError(