Generalize prompt_feature and completion_feature for use in local datasets to facilitate compatibility with many other training dataset formats.

This commit is contained in:
Chime Ogbuji 2024-11-02 18:05:27 -04:00 committed by Awni Hannun
parent bf2da36fc6
commit db9898d104
3 changed files with 51 additions and 11 deletions

llms/mlx_lm/LORA.md

@@ -247,8 +247,18 @@ Refer to the documentation for the model you are fine-tuning for more details.
{"text": "This is an example for the model."}
```
Note, the format is automatically determined by the dataset. Note also, keys in
each line not expected by the loader will be ignored.
Note, the format is automatically determined by the dataset.
For the completion data format, different keys can be used for the _prompt_ and the _completion_ by specifying
them in the YAML config. For example:
```yaml
prompt_feature: "input"
completion_feature: "output"
```
Here, `input` is the expected key instead of `prompt`, and `output` is the expected key instead of `completion`.
Note also that keys in each line not expected by the loader will be ignored.
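For example, with the config above, a line of completion data would look like the following (sample values are illustrative):
```jsonl
{"input": "What is the capital of France?", "output": "Paris is the capital of France."}
```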
> [!NOTE]
> Each example in the datasets must be on a single line. Do not put more than
@@ -270,7 +280,7 @@ Otherwise, provide a mapping of keys in the dataset to the features MLX LM
expects. Use a YAML config to specify the Hugging Face dataset arguments. For
example:
```
```yaml
hf_dataset:
name: "billsum"
prompt_feature: "text"

llms/mlx_lm/lora.py

@@ -61,6 +61,8 @@ CONFIG_DEFAULTS = {
"config": None,
"grad_checkpoint": False,
"lr_schedule": None,
"prompt_feature": "prompt",
"completion_feature": "completion",
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
}
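With these defaults, existing configs that use the standard `prompt`/`completion` keys keep working unchanged; a config only needs to set these keys when its data uses different names. A minimal sketch of such an override (the `question`/`answer` names are illustrative):
```yaml
data: "data/"                  # directory containing train.jsonl / valid.jsonl / test.jsonl
prompt_feature: "question"
completion_feature: "answer"
```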

llms/mlx_lm/tuner/datasets.py

@@ -81,12 +81,20 @@ class CompletionsDataset:
return len(self._data)
def create_dataset(data, tokenizer: PreTrainedTokenizer):
def create_dataset(
data,
tokenizer: PreTrainedTokenizer,
prompt_feature: Optional[str] = None,
completion_feature: Optional[str] = None,
):
sample = data[0]
prompt_feature = prompt_feature or "prompt"
completion_feature = completion_feature or "completion"
if "messages" in sample:
return ChatDataset(data, tokenizer)
elif "prompt" in sample and "completion" in sample:
elif prompt_feature in sample and completion_feature in sample:
return CompletionsDataset(data, tokenizer)
return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature)
elif "text" in sample:
return Dataset(data, tokenizer)
@@ -97,20 +105,30 @@ def create_dataset(data, tokenizer: PreTrainedTokenizer):
)
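To illustrate the dispatch above: a minimal sketch, assuming `tokenizer` is an already-loaded `PreTrainedTokenizer` and using hypothetical records with remapped keys:
```python
# Both custom features are present in the first sample, so create_dataset
# returns a CompletionsDataset keyed on "input"/"output".
data = [{"input": "2 + 2 = ?", "output": "4"}]
dataset = create_dataset(data, tokenizer, prompt_feature="input", completion_feature="output")

# With no features given, the defaults "prompt"/"completion" still apply.
legacy = create_dataset([{"prompt": "Hi", "completion": "Hello!"}], tokenizer)
```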
def load_local_dataset(data_path: Path, tokenizer: PreTrainedTokenizer):
def load_local_dataset(
data_path: Path,
tokenizer: PreTrainedTokenizer,
prompt_feature: str = None,
completion_feature: str = None,
):
def load_subset(path):
if not path.exists():
return []
with open(path, "r") as fid:
data = [json.loads(l) for l in fid]
return create_dataset(data, tokenizer)
return create_dataset(data, tokenizer, prompt_feature, completion_feature)
names = ("train", "valid", "test")
train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names]
return train, valid, test
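A usage sketch for the local path (the directory layout and key names are illustrative; `tokenizer` is assumed loaded):
```python
from pathlib import Path

# Looks for data/train.jsonl, data/valid.jsonl, and data/test.jsonl;
# any missing split is returned as an empty list.
train, valid, test = load_local_dataset(Path("data"), tokenizer, "input", "output")
```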
def load_hf_dataset(data_id: str, tokenizer: PreTrainedTokenizer):
def load_hf_dataset(
data_id: str,
tokenizer: PreTrainedTokenizer,
prompt_feature: str = None,
completion_feature: str = None,
):
from datasets import exceptions, load_dataset
try:
@@ -119,7 +137,13 @@ def load_hf_dataset(data_id: str, tokenizer: PreTrainedTokenizer):
names = ("train", "valid", "test")
train, valid, test = [
create_dataset(dataset[n], tokenizer) if n in dataset.keys() else []
(
create_dataset(
dataset[n], tokenizer, prompt_feature, completion_feature
)
if n in dataset.keys()
else []
)
for n in names
]
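The Hugging Face path works the same way. For example, with the `billsum` dataset used in the documentation above, whose records carry `text` and `summary` features (treating them as prompt and completion here is illustrative):
```python
# billsum has no "valid" split, so valid comes back as an empty list.
train, valid, test = load_hf_dataset("billsum", tokenizer, "text", "summary")
```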
@@ -176,10 +200,14 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer):
else:
data_path = Path(args.data)
if data_path.exists():
train, valid, test = load_local_dataset(data_path, tokenizer)
train, valid, test = load_local_dataset(
data_path, tokenizer, args.prompt_feature, args.completion_feature
)
else:
print(f"Loading Hugging Face dataset {args.data}.")
train, valid, test = load_hf_dataset(args.data, tokenizer)
train, valid, test = load_hf_dataset(
args.data, tokenizer, args.prompt_feature, args.completion_feature
)
if args.train and len(train) == 0:
raise ValueError(