LoRA: Support HuggingFace dataset via data parameter (#996)

* LoRA: support huggingface dataset via `data` argument

* LoRA: Extract the load_custom_hf_dataset function

* LoRA: split small functions

* fix spelling errors

* handle load hf dataset error

* fix pre-commit lint

* update data argument help

* nits and doc

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Author: madroid
Date: 2024-09-30 22:36:21 +08:00
Committed by: GitHub
Parent: 50e5ca81a8
Commit: aa1c8abdc6
3 changed files with 93 additions and 51 deletions
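
In short, after this change `--data` accepts either a local directory of `{train, valid, test}.jsonl` files or the name of a Hugging Face dataset. A rough usage sketch, not part of the commit; the `mlx_lm.lora` entry point, model id, and iteration count are illustrative:

```
# Fine-tune with LoRA on a pre-formatted Hugging Face dataset.
# --data now also accepts a Hub dataset name instead of a local directory.
mlx_lm.lora \
    --model mlx-community/Mistral-7B-Instruct-v0.3-4bit \
    --train \
    --data mlx-community/wikisql \
    --iters 100
```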


@@ -251,7 +251,13 @@ To use Hugging Face datasets, first install the `datasets` package:
 pip install datasets
 ```

-Specify the Hugging Face dataset arguments in a YAML config. For example:
+If the Hugging Face dataset is already in a supported format, you can specify
+it on the command line. For example, pass `--data mlx-community/wikisql` to
+train on the pre-formatted WikiSQL data.
+
+Otherwise, provide a mapping of keys in the dataset to the features MLX LM
+expects. Use a YAML config to specify the Hugging Face dataset arguments. For
+example:

 ```
 hf_dataset:
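
The YAML block referenced in the docs continues past this hunk. A sketch of what such a config might look like, using the `prompt_feature` and `completion_feature` keys that appear in the loader code below; the `name` key and the `billsum` dataset are illustrative, not taken from this diff:

```
# Excerpt of a fine-tuning YAML config: map dataset columns to the
# prompt/completion features MLX LM expects.
hf_dataset:
  name: "billsum"
  prompt_feature: "text"
  completion_feature: "summary"
```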


@@ -79,7 +79,10 @@ def build_parser():
     parser.add_argument(
         "--data",
         type=str,
-        help="Directory with {train, valid, test}.jsonl files",
+        help=(
+            "Directory with {train, valid, test}.jsonl files or the name "
+            "of a Hugging Face dataset (e.g., 'mlx-community/wikisql')"
+        ),
     )
     parser.add_argument(
         "--fine-tune-type",


@@ -76,17 +76,14 @@ class CompletionsDataset(Dataset):
         return text


-def create_dataset(path: Path, tokenizer: PreTrainedTokenizer = None):
-    # Return empty dataset for non-existent paths
-    if not path.exists():
-        return []
-    with open(path, "r") as fid:
-        data = [json.loads(l) for l in fid]
-    if "messages" in data[0]:
+def create_dataset(data, tokenizer: PreTrainedTokenizer = None):
+    sample = data[0]
+
+    if "messages" in sample:
         return ChatDataset(data, tokenizer)
-    elif "prompt" in data[0] and "completion" in data[0]:
+    elif "prompt" in sample and "completion" in sample:
         return CompletionsDataset(data, tokenizer)
-    elif "text" in data[0]:
+    elif "text" in sample:
         return Dataset(data)
     else:
         raise ValueError(
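
`create_dataset` now receives the records directly and inspects the first one to pick the dataset class. Example records for the three accepted shapes (field values are illustrative; a real file uses one format throughout):

```
{"messages": [{"role": "user", "content": "What is 2+2?"}, {"role": "assistant", "content": "4"}]}
{"prompt": "What is 2+2?", "completion": "4"}
{"text": "Raw text for completion-style training."}
```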
@@ -95,8 +92,39 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer):
         )


-def load_dataset(args, tokenizer: PreTrainedTokenizer):
-    if getattr(args, "hf_dataset", None) is not None:
-        import datasets
-
-        hf_args = args.hf_dataset
+def load_local_dataset(data_path: Path, tokenizer: PreTrainedTokenizer):
+    def load_subset(path):
+        if not path.exists():
+            return []
+        with open(path, "r") as fid:
+            data = [json.loads(l) for l in fid]
+        return create_dataset(data, tokenizer)
+
+    names = ("train", "valid", "test")
+    train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names]
+    return train, valid, test
+
+
+def load_hf_dataset(data_id: str, tokenizer: PreTrainedTokenizer):
+    from datasets import exceptions, load_dataset
+
+    try:
+        dataset = load_dataset(data_id)
+
+        names = ("train", "valid", "test")
+
+        train, valid, test = [
+            create_dataset(dataset[n], tokenizer) if n in dataset.keys() else []
+            for n in names
+        ]
+
+    except exceptions.DatasetNotFoundError:
+        raise ValueError(f"Not found Hugging Face dataset: {data_id} .")
+
+    return train, valid, test
+
+
+def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
+    import datasets
+
+    hf_args = args.hf_dataset
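
The new `load_hf_dataset` path assumes the Hub dataset already exposes `train`/`valid`/`test` splits whose rows match one of the formats `create_dataset` accepts. A quick way to check a candidate dataset before training; a sketch that needs the `datasets` package and network access:

```
# Sketch: verify a Hub dataset is in a format the loaders above accept.
from datasets import load_dataset

dataset = load_dataset("mlx-community/wikisql")  # same id used in the docs above
print(dataset.keys())    # expect 'train', 'valid', 'test'; missing splits load as []
sample = dataset["train"][0]
print(sample.keys())     # expect 'messages', 'prompt'/'completion', or 'text'
```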
@@ -113,9 +141,7 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer):
             **hf_args.get("config", {}),
         )
         if prompt_feature and completion_feature:
-            return CompletionsDataset(
-                ds, tokenizer, prompt_feature, completion_feature
-            )
+            return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
         elif text_feature:
             return Dataset(train_ds, text_key=text_feature)
         else:
else: else:
@@ -136,13 +162,20 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer):
     else:
         test = []
-    else:
-        names = ("train", "valid", "test")
-        data_path = Path(args.data)

-        train, valid, test = [
-            create_dataset(data_path / f"{n}.jsonl", tokenizer) for n in names
-        ]
+    return train, valid, test
+
+
+def load_dataset(args, tokenizer: PreTrainedTokenizer):
+    if getattr(args, "hf_dataset", None) is not None:
+        train, valid, test = load_custom_hf_dataset(args, tokenizer)
+    else:
+        data_path = Path(args.data)
+        if data_path.exists():
+            train, valid, test = load_local_dataset(data_path, tokenizer)
+        else:
+            print(f"Loading Hugging Face dataset {args.data}.")
+            train, valid, test = load_hf_dataset(args.data, tokenizer)

     if args.train and len(train) == 0:
         raise ValueError(
             "Training set not found or empty. Must provide training set for fine-tuning."