Fixes to references to hf_datasets

This commit is contained in:
Chime Ogbuji 2024-11-03 20:04:15 -05:00
parent c72122064a
commit 04cf93df55

View File

@ -201,7 +201,7 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
) )
return train, valid return train, valid
if args.datasets: if args.hf_datasets:
dataset_collection = args.hf_datasets dataset_collection = args.hf_datasets
train_collection = [] train_collection = []
valid_collection = [] valid_collection = []
@ -263,7 +263,7 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
def load_dataset(args, tokenizer: PreTrainedTokenizer): def load_dataset(args, tokenizer: PreTrainedTokenizer):
if getattr(args, "hf_dataset", None) is not None: if getattr(args, "hf_dataset", None) is not None or getattr(args, "hf_datasets"):
train, valid, test = load_custom_hf_dataset(args, tokenizer) train, valid, test = load_custom_hf_dataset(args, tokenizer)
else: else:
data_path = Path(args.data) data_path = Path(args.data)