diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py
index 2abea970..a7fb2177 100644
--- a/llms/mlx_lm/tuner/datasets.py
+++ b/llms/mlx_lm/tuner/datasets.py
@@ -117,54 +117,72 @@ def create_dataset(path: Path, tokenizer: PreTrainedTokenizer = None):
         )
 
 
+def create_local_dataset(args, tokenizer: PreTrainedTokenizer):
+    """Build the train, valid and test sets from the ``.jsonl`` files in ``args.data``."""
+    names = ("train", "valid", "test")
+    data_path = Path(args.data)
+
+    train, valid, test = [
+        create_dataset(data_path / f"{n}.jsonl", tokenizer) for n in names
+    ]
+    return train, valid, test
+
+
+def create_hf_dataset(args, tokenizer: PreTrainedTokenizer):
+    """Build the train, valid and test sets described by ``args.hf_dataset``.
+
+    The train and valid splits are loaded only when ``args.train`` is set and
+    the test split only when ``args.test`` is set; a split that is not
+    requested is returned as an empty list.
+    """
+    import datasets
+
+    hf_args = args.hf_dataset
+    dataset_name = hf_args["name"]
+    print(f"Loading Hugging Face dataset {dataset_name}.")
+    text_feature = hf_args.get("text_feature")
+    prompt_feature = hf_args.get("prompt_feature")
+    completion_feature = hf_args.get("completion_feature")
+
+    def create_hf_dataset(split: str = None):
+        ds = datasets.load_dataset(
+            dataset_name,
+            split=split,
+            **hf_args.get("config", {}),
+        )
+        if prompt_feature and completion_feature:
+            return CompletionsDataset(
+                ds, tokenizer, prompt_feature, completion_feature
+            )
+        elif text_feature:
+            return Dataset(ds, text_key=text_feature)
+        else:
+            raise ValueError(
+                "Specify either a prompt and completion feature or a text "
+                "feature for the Hugging Face dataset."
+            )
+
+    if args.train:
+        train_split = hf_args.get("train_split", "train[:80%]")
+        valid_split = hf_args.get("valid_split", "train[-10%:]")
+        train = create_hf_dataset(split=train_split)
+        valid = create_hf_dataset(split=valid_split)
+    else:
+        train, valid = [], []
+    if args.test:
+        test = create_hf_dataset(split=hf_args.get("test_split"))
+    else:
+        test = []
+
+    return train, valid, test
+
+
 def load_dataset(args, tokenizer: PreTrainedTokenizer):
     if getattr(args, "hf_dataset", None) is not None:
-        import datasets
-
-        hf_args = args.hf_dataset
-        dataset_name = hf_args["name"]
-        print(f"Loading Hugging Face dataset {dataset_name}.")
-        text_feature = hf_args.get("text_feature")
-        prompt_feature = hf_args.get("prompt_feature")
-        completion_feature = hf_args.get("completion_feature")
-
-        def create_hf_dataset(split: str = None):
-            ds = datasets.load_dataset(
-                dataset_name,
-                split=split,
-                **hf_args.get("config", {}),
-            )
-            if prompt_feature and completion_feature:
-                return CompletionsDataset(
-                    ds, tokenizer, prompt_feature, completion_feature
-                )
-            elif text_feature:
-                return Dataset(train_ds, text_key=text_feature)
-            else:
-                raise ValueError(
-                    "Specify either a prompt and completion feature or a text "
-                    "feature for the Hugging Face dataset."
-                )
-
-        if args.train:
-            train_split = hf_args.get("train_split", "train[:80%]")
-            valid_split = hf_args.get("valid_split", "train[-10%:]")
-            train = create_hf_dataset(split=train_split)
-            valid = create_hf_dataset(split=valid_split)
-        else:
-            train, valid = [], []
-        if args.test:
-            test = create_hf_dataset(split=hf_args.get("test_split"))
-        else:
-            test = []
-
+        train, valid, test = create_hf_dataset(args, tokenizer)
     else:
-        names = ("train", "valid", "test")
-        data_path = Path(args.data)
+        train, valid, test = create_local_dataset(args, tokenizer)
 
-        train, valid, test = [
-            create_dataset(data_path / f"{n}.jsonl", tokenizer) for n in names
-        ]
     if args.train and len(train) == 0:
         raise ValueError(
             "Training set not found or empty. Must provide training set for fine-tuning."