diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py
index f26e973b..b80b26f8 100644
--- a/llms/mlx_lm/tuner/datasets.py
+++ b/llms/mlx_lm/tuner/datasets.py
@@ -73,17 +73,14 @@ class CompletionsDataset(Dataset):
         return text
 
 
-def create_dataset(path: Path, tokenizer: PreTrainedTokenizer = None):
-    # Return empty dataset for non-existent paths
-    if not path.exists():
-        return []
-    with open(path, "r") as fid:
-        data = [json.loads(l) for l in fid]
-    if "messages" in data[0]:
+def create_dataset(data, tokenizer: PreTrainedTokenizer = None):
+    sample = data[0]
+
+    if "messages" in sample:
         return ChatDataset(data, tokenizer)
-    elif "prompt" in data[0] and "completion" in data[0]:
+    elif "prompt" in sample and "completion" in sample:
         return CompletionsDataset(data, tokenizer)
-    elif "text" in data[0]:
+    elif "text" in sample:
         return Dataset(data)
     else:
         raise ValueError(
@@ -92,31 +89,31 @@
         )
 
 
+def load_local_data(path: Path, tokenizer: PreTrainedTokenizer):
+    if not path.exists():
+        return []
+    with open(path, "r") as fid:
+        data = [json.loads(l) for l in fid]
+
+    return create_dataset(data, tokenizer)
+
+
+def load_local_dataset(data_path: Path, tokenizer: PreTrainedTokenizer):
+    names = ("train", "valid", "test")
+    train, valid, test = [
+        load_local_data(data_path / f"{n}.jsonl", tokenizer) for n in names
+    ]
+    return train, valid, test
+
+
 def load_hf_dataset(data_id: str, tokenizer: PreTrainedTokenizer):
     import datasets
 
     datasets = datasets.load_dataset(data_id)
 
-    def create(data):
-        sample = data[0]
-
-        if "messages" in sample:
-            return ChatDataset(data, tokenizer)
-        elif "prompt" in sample and "completion" in sample:
-            return CompletionsDataset(data, tokenizer)
-        elif "text" in sample:
-            return Dataset(data)
-        else:
-            raise ValueError(
-                "Unsupported data format, check the supported formats here:\n"
-                "https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md#data."
-            )
-
     names = ("train", "valid", "test")
-    train, valid, test = [
-        create(datasets[n], tokenizer) for n in names
-    ]
+    train, valid, test = [create_dataset(datasets[n], tokenizer) for n in names]
 
     return train, valid, test
 
@@ -137,9 +134,7 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
             **hf_args.get("config", {}),
         )
         if prompt_feature and completion_feature:
-            return CompletionsDataset(
-                ds, tokenizer, prompt_feature, completion_feature
-            )
+            return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
         elif text_feature:
             return Dataset(train_ds, text_key=text_feature)
         else:
@@ -169,11 +164,9 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer):
     else:
         data_path = Path(args.data)
         if data_path.exists():
-            names = ("train", "valid", "test")
-            train, valid, test = [
-                create_dataset(data_path / f"{n}.jsonl", tokenizer) for n in names
-            ]
+            train, valid, test = load_local_dataset(data_path, tokenizer)
         else:
+            print(f"Loading Hugging Face dataset {args.data}.")
            train, valid, test = load_hf_dataset(args.data, tokenizer)
 
     if args.train and len(train) == 0:
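
After this change, create_dataset dispatches purely on the keys of in-memory
records, so the same format detection serves both local JSONL files and
Hugging Face datasets; the duplicated create() closure inside load_hf_dataset
goes away. A minimal sketch of the resulting dispatch, not part of the patch
itself; it assumes mlx_lm is installed, and the tokenizer id is an arbitrary
example, any Hugging Face tokenizer would do:

# Sketch (illustrative only): exercising the refactored create_dataset
# dispatch on in-memory records. The tokenizer id below is an assumption,
# not something this patch prescribes.
from transformers import AutoTokenizer
from mlx_lm.tuner.datasets import create_dataset

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")

# Records with a "messages" key dispatch to ChatDataset,
chat = create_dataset(
    [{"messages": [{"role": "user", "content": "Hi"},
                   {"role": "assistant", "content": "Hello!"}]}],
    tokenizer,
)

# "prompt"/"completion" records dispatch to CompletionsDataset,
completions = create_dataset(
    [{"prompt": "2 + 2 =", "completion": " 4"}], tokenizer
)

# and plain "text" records dispatch to Dataset.
text = create_dataset([{"text": "Lorem ipsum dolor sit amet."}], tokenizer)

print(type(chat).__name__, type(completions).__name__, type(text).__name__)
# Expected: ChatDataset CompletionsDataset Dataset

Keeping the format detection in one place means the local and Hugging Face
loading paths can no longer drift apart, which is the point of the refactor.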