From 7b0141455ef8091a4946076213e55426e6b78bd0 Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Tue, 4 Feb 2025 10:43:00 +0100
Subject: [PATCH] better create_dataset

---
 llms/mlx_lm/tuner/datasets.py | 25 +++++++++++++------------
 1 file changed, 13 insertions(+), 12 deletions(-)

diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py
index 23e293c4..983bd8e3 100644
--- a/llms/mlx_lm/tuner/datasets.py
+++ b/llms/mlx_lm/tuner/datasets.py
@@ -141,9 +141,19 @@ def create_dataset(
     completion_feature = completion_feature or "completion"
 
     sample = data[0]
-    if "messages" in sample:
-        return ChatDataset(data, tokenizer)
-    elif "prompt" in sample and "answer" in sample:
+    if args.training_mode == "normal":
+        if "messages" in sample:
+            return ChatDataset(data, tokenizer)
+        elif prompt_feature in sample and completion_feature in sample:
+            return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature)
+        elif "text" in sample:
+            return Dataset(data, tokenizer)
+        else:
+            raise ValueError(
+                "Unsupported data format, check the supported formats here:\n"
+                "https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md#data."
+            )
+    else:
         return GRPODataset(
             data=data,
             tokenizer=tokenizer,
@@ -152,15 +162,6 @@ def create_dataset(
             use_chat_template=args.use_chat_template,
             use_prompt=args.use_prompt
         )
-    elif prompt_feature in sample and completion_feature in sample:
-        return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature)
-    elif "text" in sample:
-        return Dataset(data, tokenizer)
-    else:
-        raise ValueError(
-            "Unsupported data format, check the supported formats here:\n"
-            "https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md#data."
-        )
 
 
 def load_local_dataset(
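
For reference, a minimal standalone sketch of the dispatch logic this patch introduces, assuming only what the diff shows: a "normal" training mode keeps the original key-based dataset selection, and any other mode routes to GRPODataset. The helper name pick_dataset_class is hypothetical and the dataset classes below are empty stubs (not the mlx_lm implementations), so the sketch runs without mlx_lm installed.

# Standalone sketch of the post-patch branching in create_dataset (not mlx_lm code).
# The class names mirror the dataset classes referenced in the diff; they are
# stubbed out here purely so the control flow can be exercised in isolation.

class Dataset: ...
class ChatDataset: ...
class CompletionsDataset: ...
class GRPODataset: ...


def pick_dataset_class(training_mode, sample, prompt_feature="prompt", completion_feature="completion"):
    # "normal" training keeps the original key-based selection ...
    if training_mode == "normal":
        if "messages" in sample:
            return ChatDataset
        if prompt_feature in sample and completion_feature in sample:
            return CompletionsDataset
        if "text" in sample:
            return Dataset
        raise ValueError("Unsupported data format, see llms/mlx_lm/LORA.md#data.")
    # ... while any non-"normal" mode (e.g. GRPO) gets the GRPO dataset
    return GRPODataset


assert pick_dataset_class("normal", {"messages": []}) is ChatDataset
assert pick_dataset_class("normal", {"prompt": "p", "completion": "c"}) is CompletionsDataset
assert pick_dataset_class("normal", {"text": "t"}) is Dataset
assert pick_dataset_class("grpo", {"prompt": "p", "answer": "a"}) is GRPODataset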