rebase + nits

This commit is contained in:
Awni Hannun 2025-01-13 09:43:23 -08:00
parent 40438b1371
commit 7499720b09
3 changed files with 26 additions and 26 deletions

View File

@ -241,24 +241,25 @@ Refer to the documentation for the model you are fine-tuning for more details.
{"prompt": "What is the capital of France?", "completion": "Paris."} {"prompt": "What is the capital of France?", "completion": "Paris."}
``` ```
`text`: For the `completions` data format, a different key can be used for the prompt
and completion by specifying the following in the YAML config:
```jsonl
{"text": "This is an example for the model."}
```
Note, the format is automatically determined by the dataset.
For the completion data format, a different key can be used for the _prompt_ and for the _completion_ by specifying
the following, for example, in the YAML config:
```yaml ```yaml
prompt_feature: "input" prompt_feature: "input"
completion_feature: "output" completion_feature: "output"
``` ```
Here, `input` is now the expected key instead of "prompt" and `output` is the expected key instead of "completion". Here, `"input"` is the expected key instead of the default `"prompt"`, and
Note also, keys in each line not expected by the loader will be ignored. `"output"` is the expected key instead of `"completion"`.
`text`:
```jsonl
{"text": "This is an example for the model."}
```
Note, the format is automatically determined by the dataset. Note also, keys
in each line not expected by the loader will be ignored.
> [!NOTE] > [!NOTE]
> Each example in the datasets must be on a single line. Do not put more than > Each example in the datasets must be on a single line. Do not put more than

View File

@ -61,8 +61,6 @@ CONFIG_DEFAULTS = {
"config": None, "config": None,
"grad_checkpoint": False, "grad_checkpoint": False,
"lr_schedule": None, "lr_schedule": None,
"prompt_feature": "prompt",
"completion_feature": "completion",
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}, "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
} }

View File

@ -1,6 +1,6 @@
import json import json
from pathlib import Path from pathlib import Path
from typing import Dict, List from typing import Dict, List, Optional
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
@ -61,8 +61,8 @@ class CompletionsDataset:
self, self,
data: List[Dict[str, str]], data: List[Dict[str, str]],
tokenizer: PreTrainedTokenizer, tokenizer: PreTrainedTokenizer,
prompt_key: str = "prompt", prompt_key: str,
completion_key: str = "completion", completion_key: str,
): ):
self._data = [ self._data = [
tokenizer.apply_chat_template( tokenizer.apply_chat_template(
@ -81,17 +81,15 @@ class CompletionsDataset:
return len(self._data) return len(self._data)
<<<<<<< HEAD
def create_dataset( def create_dataset(
data, data,
tokenizer: PreTrainedTokenizer, tokenizer: PreTrainedTokenizer,
prompt_feature: Optional[str] = None, prompt_feature: Optional[str] = None,
completion_feature: Optional[str] = None, completion_feature: Optional[str] = None,
): ):
sample = data[0]
prompt_feature = prompt_feature or "prompt" prompt_feature = prompt_feature or "prompt"
completion_feature = completion_feature or "completion" completion_feature = completion_feature or "completion"
sample = data[0]
if "messages" in sample: if "messages" in sample:
return ChatDataset(data, tokenizer) return ChatDataset(data, tokenizer)
elif prompt_feature in sample and completion_feature in sample: elif prompt_feature in sample and completion_feature in sample:
@ -108,8 +106,8 @@ def create_dataset(
def load_local_dataset( def load_local_dataset(
data_path: Path, data_path: Path,
tokenizer: PreTrainedTokenizer, tokenizer: PreTrainedTokenizer,
prompt_feature: str = None, prompt_feature: Optional[str] = None,
completion_feature: str = None, completion_feature: Optional[str] = None,
): ):
def load_subset(path): def load_subset(path):
if not path.exists(): if not path.exists():
@ -126,8 +124,8 @@ def load_local_dataset(
def load_hf_dataset( def load_hf_dataset(
data_id: str, data_id: str,
tokenizer: PreTrainedTokenizer, tokenizer: PreTrainedTokenizer,
prompt_feature: str = None, prompt_feature: Optional[str] = None,
completion_feature: str = None, completion_feature: Optional[str] = None,
): ):
from datasets import exceptions, load_dataset from datasets import exceptions, load_dataset
@ -199,14 +197,17 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer):
train, valid, test = load_custom_hf_dataset(args, tokenizer) train, valid, test = load_custom_hf_dataset(args, tokenizer)
else: else:
data_path = Path(args.data) data_path = Path(args.data)
prompt_feature = getattr(args, "prompt_feature", None)
completion_feature = getattr(args, "completion_feature", None)
if data_path.exists(): if data_path.exists():
train, valid, test = load_local_dataset( train, valid, test = load_local_dataset(
data_path, tokenizer, args.prompt_feature, args.completion_feature data_path, tokenizer, prompt_feature, completion_feature
) )
else: else:
print(f"Loading Hugging Face dataset {args.data}.") print(f"Loading Hugging Face dataset {args.data}.")
train, valid, test = load_hf_dataset( train, valid, test = load_hf_dataset(
args.data, tokenizer, args.prompt_feature, args.completion_feature args.data, tokenizer, prompt_feature, completion_feature
) )
if args.train and len(train) == 0: if args.train and len(train) == 0: