From 7499720b099bbf12fa73422cde68d7a990384ab1 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 13 Jan 2025 09:43:23 -0800 Subject: [PATCH] rebase + nits --- llms/mlx_lm/LORA.md | 25 +++++++++++++------------ llms/mlx_lm/lora.py | 2 -- llms/mlx_lm/tuner/datasets.py | 25 +++++++++++++------------ 3 files changed, 26 insertions(+), 26 deletions(-) diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md index e1a86325..9eac9d7f 100644 --- a/llms/mlx_lm/LORA.md +++ b/llms/mlx_lm/LORA.md @@ -241,24 +241,25 @@ Refer to the documentation for the model you are fine-tuning for more details. {"prompt": "What is the capital of France?", "completion": "Paris."} ``` -`text`: - -```jsonl -{"text": "This is an example for the model."} -``` - -Note, the format is automatically determined by the dataset. - -For the completion data format, a different key can be used for the _prompt_ and for the _completion_ by specifying -the following, for example, in the YAML config: +For the `completions` data format, a different key can be used for the prompt +and completion by specifying the following in the YAML config: ```yaml prompt_feature: "input" completion_feature: "output" ``` -Here, `input` is now the expected key instead of "prompt" and `output` is the expected key instead of "completion". -Note also, keys in each line not expected by the loader will be ignored. +Here, `"input"` is the expected key instead of the default `"prompt"`, and +`"output"` is the expected key instead of `"completion"`. + +`text`: + +```jsonl +{"text": "This is an example for the model."} +``` + +Note, the format is automatically determined by the dataset. Note also, keys +in each line not expected by the loader will be ignored. > [!NOTE] > Each example in the datasets must be on a single line. Do not put more than diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py index 41a618b1..4d050bd5 100644 --- a/llms/mlx_lm/lora.py +++ b/llms/mlx_lm/lora.py @@ -61,8 +61,6 @@ CONFIG_DEFAULTS = { "config": None, "grad_checkpoint": False, "lr_schedule": None, - "prompt_feature": "prompt", - "completion_feature": "completion", "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}, } diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py index 81e5d293..1b09c7e2 100644 --- a/llms/mlx_lm/tuner/datasets.py +++ b/llms/mlx_lm/tuner/datasets.py @@ -1,6 +1,6 @@ import json from pathlib import Path -from typing import Dict, List +from typing import Dict, List, Optional from transformers import PreTrainedTokenizer @@ -61,8 +61,8 @@ class CompletionsDataset: self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer, - prompt_key: str = "prompt", - completion_key: str = "completion", + prompt_key: str, + completion_key: str, ): self._data = [ tokenizer.apply_chat_template( @@ -81,17 +81,15 @@ class CompletionsDataset: return len(self._data) -<<<<<<< HEAD def create_dataset( data, tokenizer: PreTrainedTokenizer, prompt_feature: Optional[str] = None, completion_feature: Optional[str] = None, ): - sample = data[0] prompt_feature = prompt_feature or "prompt" completion_feature = completion_feature or "completion" - + sample = data[0] if "messages" in sample: return ChatDataset(data, tokenizer) elif prompt_feature in sample and completion_feature in sample: @@ -108,8 +106,8 @@ def create_dataset( def load_local_dataset( data_path: Path, tokenizer: PreTrainedTokenizer, - prompt_feature: str = None, - completion_feature: str = None, + prompt_feature: Optional[str] = None, + completion_feature: Optional[str] = None, ): def load_subset(path): if not path.exists(): @@ -126,8 +124,8 @@ def load_local_dataset( def load_hf_dataset( data_id: str, tokenizer: PreTrainedTokenizer, - prompt_feature: str = None, - completion_feature: str = None, + prompt_feature: Optional[str] = None, + completion_feature: Optional[str] = None, ): from datasets import exceptions, load_dataset @@ -199,14 +197,17 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer): train, valid, test = load_custom_hf_dataset(args, tokenizer) else: data_path = Path(args.data) + + prompt_feature = getattr(args, "prompt_feature", None) + completion_feature = getattr(args, "completion_feature", None) if data_path.exists(): train, valid, test = load_local_dataset( - data_path, tokenizer, args.prompt_feature, args.completion_feature + data_path, tokenizer, prompt_feature, completion_feature ) else: print(f"Loading Hugging Face dataset {args.data}.") train, valid, test = load_hf_dataset( - args.data, tokenizer, args.prompt_feature, args.completion_feature + args.data, tokenizer, prompt_feature, completion_feature ) if args.train and len(train) == 0: