LoRA: Split small function

This commit is contained in:
madroid 2024-09-20 11:06:06 +08:00
parent bfd4ba2347
commit 3f6a5f19fd

View File

@@ -117,8 +117,17 @@ def create_dataset(path: Path, tokenizer: PreTrainedTokenizer = None):
)
def load_dataset(args, tokenizer: PreTrainedTokenizer):
if getattr(args, "hf_dataset", None) is not None:
def create_local_dataset(args, tokenizer: PreTrainedTokenizer):
    """Load the train/valid/test splits from local JSONL files.

    Each split is read from ``<args.data>/<split>.jsonl`` via ``create_dataset``.

    Args:
        args: Config object whose ``data`` attribute is the dataset directory.
        tokenizer: Tokenizer forwarded to ``create_dataset`` for each split.

    Returns:
        A ``(train, valid, test)`` tuple of datasets.
    """
    base = Path(args.data)
    # Unpack directly from a generator — one dataset per fixed split name.
    train, valid, test = (
        create_dataset(base / f"{split}.jsonl", tokenizer)
        for split in ("train", "valid", "test")
    )
    return train, valid, test
def create_hf_dataset(args, tokenizer: PreTrainedTokenizer):
import datasets
hf_args = args.hf_dataset
@@ -158,13 +167,15 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer):
else:
test = []
else:
names = ("train", "valid", "test")
data_path = Path(args.data)
return train, valid, test
def load_dataset(args, tokenizer: PreTrainedTokenizer):
if getattr(args, "hf_dataset", None) is not None:
train, valid, test = create_hf_dataset(args, tokenizer)
else:
train, valid, test = create_local_dataset(args, tokenizer)
train, valid, test = [
create_dataset(data_path / f"{n}.jsonl", tokenizer) for n in names
]
if args.train and len(train) == 0:
raise ValueError(
"Training set not found or empty. Must provide training set for fine-tuning."