diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py index a0c16a28..3b442c6a 100644 --- a/llms/mlx_lm/tuner/datasets.py +++ b/llms/mlx_lm/tuner/datasets.py @@ -201,7 +201,7 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer): ) return train, valid - if args.datasets: + if args.hf_datasets: dataset_collection = args.hf_datasets train_collection = [] valid_collection = [] @@ -263,7 +263,7 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer): def load_dataset(args, tokenizer: PreTrainedTokenizer): - if getattr(args, "hf_dataset", None) is not None: + if getattr(args, "hf_dataset", None) is not None or getattr(args, "hf_datasets"): train, valid, test = load_custom_hf_dataset(args, tokenizer) else: data_path = Path(args.data)