From 1beefd58a044c0af9eeb0a58c9901465d61b37df Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Tue, 4 Feb 2025 11:06:57 +0100
Subject: [PATCH] add create_dataset

---
 llms/mlx_lm/tuner/datasets.py | 41 +++++++++++++++++++++--------------
 1 file changed, 25 insertions(+), 16 deletions(-)

diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py
index 2330d4ac..27b74b34 100644
--- a/llms/mlx_lm/tuner/datasets.py
+++ b/llms/mlx_lm/tuner/datasets.py
@@ -140,6 +140,7 @@ class CompletionsDataset:
 
 
 def create_dataset(
+    args,
     data,
     tokenizer: PreTrainedTokenizer,
     prompt_feature: Optional[str] = None,
@@ -148,24 +149,31 @@ def create_dataset(
     prompt_feature = prompt_feature or "prompt"
     completion_feature = completion_feature or "completion"
     sample = data[0]
-
-    # Add DPO dataset support
-    if "chosen" in sample and "rejected" in sample:
-        return ORPODataset(data, tokenizer)
-    elif "messages" in sample:
-        return ChatDataset(data, tokenizer)
-    elif prompt_feature in sample and completion_feature in sample:
-        return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature)
-    elif "text" in sample:
-        return Dataset(data, tokenizer)
+
+    if args.training_mode == "normal":
+        if "messages" in sample:
+            return ChatDataset(data, tokenizer)
+        elif prompt_feature in sample and completion_feature in sample:
+            return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature)
+        elif "text" in sample:
+            return Dataset(data, tokenizer)
+        else:
+            raise ValueError(
+                "Unsupported data format, check the supported formats here:\n"
+                "https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md#data."
+            )
+    elif args.training_mode == "orpo":
+        if "chosen" in sample and "rejected" in sample:
+            return ORPODataset(data, tokenizer)
     else:
         raise ValueError(
-            "Unsupported data format, check the supported formats here:\n"
-            "https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md#data."
+            "Unsupported training mode, check the supported training modes and their formats here:\n"
+            "https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md#training-modes."
         )
 
 
 def load_local_dataset(
+    args,
     data_path: Path,
     tokenizer: PreTrainedTokenizer,
     prompt_feature: Optional[str] = None,
@@ -176,7 +184,7 @@ def load_local_dataset(
             return []
         with open(path, "r") as fid:
             data = [json.loads(l) for l in fid]
-        return create_dataset(data, tokenizer, prompt_feature, completion_feature)
+        return create_dataset(args, data, tokenizer, prompt_feature, completion_feature)
 
     names = ("train", "valid", "test")
     train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names]
@@ -184,6 +192,7 @@
 
 
 def load_hf_dataset(
+    args,
     data_id: str,
     tokenizer: PreTrainedTokenizer,
     prompt_feature: Optional[str] = None,
@@ -199,7 +208,7 @@
         train, valid, test = [
             (
                 create_dataset(
-                    dataset[n], tokenizer, prompt_feature, completion_feature
+                    args, dataset[n], tokenizer, prompt_feature, completion_feature
                 )
                 if n in dataset.keys()
                 else []
@@ -264,12 +273,12 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer):
     completion_feature = getattr(args, "completion_feature", None)
     if data_path.exists():
         train, valid, test = load_local_dataset(
-            data_path, tokenizer, prompt_feature, completion_feature
+            args, data_path, tokenizer, prompt_feature, completion_feature
        )
    else:
        print(f"Loading Hugging Face dataset {args.data}.")
        train, valid, test = load_hf_dataset(
-            args.data, tokenizer, prompt_feature, completion_feature
+            args, args.data, tokenizer, prompt_feature, completion_feature
        )
 
     if args.train and len(train) == 0: