From a7e414687e89d0c457ae0a555053593cf4aa5d37 Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Tue, 4 Feb 2025 10:45:23 +0100
Subject: [PATCH] update create_dataset

---
 llms/mlx_lm/tuner/datasets.py | 36 +++++++++++++++++++++++++++++++----
 1 file changed, 32 insertions(+), 4 deletions(-)

diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py
index 377e7cae..f753b3f6 100644
--- a/llms/mlx_lm/tuner/datasets.py
+++ b/llms/mlx_lm/tuner/datasets.py
@@ -102,8 +102,35 @@ def create_dataset(
         "https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md#data."
     )
 
+
+def create_dataset(
+    args,
+    data,
+    tokenizer: PreTrainedTokenizer,
+    prompt_feature: Optional[str] = None,
+    completion_feature: Optional[str] = None,
+):
+    prompt_feature = prompt_feature or "prompt"
+    completion_feature = completion_feature or "completion"
+    sample = data[0]
+
+    if args.training_mode == "normal":
+        if "messages" in sample:
+            return ChatDataset(data, tokenizer)
+        elif prompt_feature in sample and completion_feature in sample:
+            return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature)
+        elif "text" in sample:
+            return Dataset(data, tokenizer)
+        else:
+            raise ValueError(
+                "Unsupported data format, check the supported formats here:\n"
+                "https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md#data."
+            )
+    else:
+        return ""
 
 
 def load_local_dataset(
+    args,
     data_path: Path,
     tokenizer: PreTrainedTokenizer,
     prompt_feature: Optional[str] = None,
@@ -114,7 +141,7 @@ def load_local_dataset(
             return []
         with open(path, "r") as fid:
             data = [json.loads(l) for l in fid]
-        return create_dataset(data, tokenizer, prompt_feature, completion_feature)
+        return create_dataset(args, data, tokenizer, prompt_feature, completion_feature)
 
     names = ("train", "valid", "test")
     train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names]
@@ -122,6 +149,7 @@
 
 
 def load_hf_dataset(
+    args,
     data_id: str,
     tokenizer: PreTrainedTokenizer,
     prompt_feature: Optional[str] = None,
@@ -137,7 +165,7 @@
     train, valid, test = [
         (
             create_dataset(
-                dataset[n], tokenizer, prompt_feature, completion_feature
+                args, dataset[n], tokenizer, prompt_feature, completion_feature
             )
             if n in dataset.keys()
             else []
@@ -202,12 +230,12 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer):
         completion_feature = getattr(args, "completion_feature", None)
         if data_path.exists():
             train, valid, test = load_local_dataset(
-                data_path, tokenizer, prompt_feature, completion_feature
+                args, data_path, tokenizer, prompt_feature, completion_feature
             )
         else:
             print(f"Loading Hugging Face dataset {args.data}.")
             train, valid, test = load_hf_dataset(
-                args.data, tokenizer, prompt_feature, completion_feature
+                args, args.data, tokenizer, prompt_feature, completion_feature
            )
 
     if args.train and len(train) == 0:
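
Usage note (not part of the patch): after this change, every dataset loader
takes the parsed CLI args as its first argument, and create_dataset reads
args.training_mode from it. Below is a minimal sketch of how the patched
entry point would be driven, assuming mlx_lm is installed from this tree;
the SimpleNamespace stand-in for the parsed args, the toy chat sample, and
the model id are illustrative assumptions, not part of the patch.

    # Hypothetical driver for the patched create_dataset; not part of the diff.
    from types import SimpleNamespace

    from transformers import AutoTokenizer

    from mlx_lm.tuner.datasets import create_dataset

    # The patch only reads args.training_mode; "normal" preserves the old
    # messages / prompt-completion / text dispatch. Any other value currently
    # falls through to the `return ""` placeholder.
    args = SimpleNamespace(training_mode="normal")

    # Any tokenizer with a chat template works for the "messages" path;
    # this model id is only an example.
    tokenizer = AutoTokenizer.from_pretrained(
        "mlx-community/Llama-3.2-1B-Instruct-4bit"
    )

    # create_dataset inspects data[0] to choose the wrapper class.
    data = [
        {
            "messages": [
                {"role": "user", "content": "Hi"},
                {"role": "assistant", "content": "Hello!"},
            ]
        }
    ]

    dataset = create_dataset(args, data, tokenizer)  # -> ChatDataset

Threading the whole args namespace through load_local_dataset and
load_hf_dataset, rather than adding a lone training_mode parameter, keeps
the signatures stable if later training modes need more CLI fields; note
that the non-"normal" branch currently returns an empty-string placeholder
rather than a dataset.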