diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py index bef354c9..5e4e7ece 100644 --- a/llms/mlx_lm/tuner/datasets.py +++ b/llms/mlx_lm/tuner/datasets.py @@ -254,7 +254,7 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer): def load_dataset(args, tokenizer: PreTrainedTokenizer): - if getattr(args, "hf_dataset", False): + if getattr(args, "hf_dataset", False) or getattr(args, "hf_datasets", False): train, valid, test = load_custom_hf_dataset(args, tokenizer) else: data_path = Path(args.data)