From db9898d10484d7680c597c5c6c7db10fa8a5ba6c Mon Sep 17 00:00:00 2001
From: Chime Ogbuji
Date: Sat, 2 Nov 2024 18:05:27 -0400
Subject: [PATCH] Generalize prompt_feature and completion_feature for use in
 local datasets to facilitate compatibility with many other training dataset
 formats.

---
 llms/mlx_lm/LORA.md           | 16 ++++++++++---
 llms/mlx_lm/lora.py           |  2 ++
 llms/mlx_lm/tuner/datasets.py | 43 +++++++++++++++++++++++++++-------
 3 files changed, 50 insertions(+), 11 deletions(-)

diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md
index 15676360..e1a86325 100644
--- a/llms/mlx_lm/LORA.md
+++ b/llms/mlx_lm/LORA.md
@@ -247,8 +247,18 @@ Refer to the documentation for the model you are fine-tuning for more details.
 {"text": "This is an example for the model."}
 ```
 
-Note, the format is automatically determined by the dataset. Note also, keys in
-each line not expected by the loader will be ignored.
+Note, the format is automatically determined by the dataset.
+
+For the completion data format, different keys can be used for the _prompt_ and
+the _completion_ by specifying them in the YAML config, for example:
+
+```yaml
+prompt_feature: "input"
+completion_feature: "output"
+```
+
+Here, `input` is the expected key instead of `prompt`, and `output` is the expected key instead of `completion`.
+Note also, keys in each line not expected by the loader will be ignored.
 
 > [!NOTE]
 > Each example in the datasets must be on a single line. Do not put more than
@@ -270,7 +280,7 @@ Otherwise, provide a mapping of keys in the dataset to the features MLX LM
 expects. Use a YAML config to specify the Hugging Face dataset arguments. For
 example:
 
-```
+```yaml
 hf_dataset:
   name: "billsum"
   prompt_feature: "text"
diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index 4d050bd5..41a618b1 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -61,6 +61,8 @@ CONFIG_DEFAULTS = {
     "config": None,
     "grad_checkpoint": False,
     "lr_schedule": None,
+    "prompt_feature": "prompt",
+    "completion_feature": "completion",
     "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
 }
 
diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py
index fa848f47..692bdb5c 100644
--- a/llms/mlx_lm/tuner/datasets.py
+++ b/llms/mlx_lm/tuner/datasets.py
@@ -81,12 +81,19 @@
         return len(self._data)
 
 
-def create_dataset(data, tokenizer: PreTrainedTokenizer):
+def create_dataset(
+    data,
+    tokenizer: PreTrainedTokenizer,
+    prompt_feature: Optional[str] = None,
+    completion_feature: Optional[str] = None,
+):
     sample = data[0]
+    prompt_feature = prompt_feature or "prompt"
+    completion_feature = completion_feature or "completion"
 
     if "messages" in sample:
         return ChatDataset(data, tokenizer)
-    elif "prompt" in sample and "completion" in sample:
+    elif prompt_feature in sample and completion_feature in sample:
         return CompletionsDataset(data, tokenizer)
     elif "text" in sample:
         return Dataset(data, tokenizer)
@@ -97,20 +104,30 @@
     )
 
 
-def load_local_dataset(data_path: Path, tokenizer: PreTrainedTokenizer):
+def load_local_dataset(
+    data_path: Path,
+    tokenizer: PreTrainedTokenizer,
+    prompt_feature: Optional[str] = None,
+    completion_feature: Optional[str] = None,
+):
     def load_subset(path):
         if not path.exists():
             return []
         with open(path, "r") as fid:
             data = [json.loads(l) for l in fid]
-        return create_dataset(data, tokenizer)
+        return create_dataset(data, tokenizer, prompt_feature, completion_feature)
 
     names = ("train", "valid", "test")
     train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names]
     return train, valid, test
 
 
-def load_hf_dataset(data_id: str, tokenizer: PreTrainedTokenizer):
+def load_hf_dataset(
+    data_id: str,
+    tokenizer: PreTrainedTokenizer,
+    prompt_feature: Optional[str] = None,
+    completion_feature: Optional[str] = None,
+):
     from datasets import exceptions, load_dataset
 
     try:
@@ -119,7 +136,13 @@
         names = ("train", "valid", "test")
 
         train, valid, test = [
-            create_dataset(dataset[n], tokenizer) if n in dataset.keys() else []
+            (
+                create_dataset(
+                    dataset[n], tokenizer, prompt_feature, completion_feature
+                )
+                if n in dataset.keys()
+                else []
+            )
             for n in names
         ]
 
@@ -176,10 +199,14 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer):
     else:
         data_path = Path(args.data)
        if data_path.exists():
-            train, valid, test = load_local_dataset(data_path, tokenizer)
+            train, valid, test = load_local_dataset(
+                data_path, tokenizer, args.prompt_feature, args.completion_feature
+            )
         else:
             print(f"Loading Hugging Face dataset {args.data}.")
-            train, valid, test = load_hf_dataset(args.data, tokenizer)
+            train, valid, test = load_hf_dataset(
+                args.data, tokenizer, args.prompt_feature, args.completion_feature
+            )
 
     if args.train and len(train) == 0:
         raise ValueError(