diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py index 692bdb5c..81e5d293 100644 --- a/llms/mlx_lm/tuner/datasets.py +++ b/llms/mlx_lm/tuner/datasets.py @@ -95,7 +95,7 @@ def create_dataset( if "messages" in sample: return ChatDataset(data, tokenizer) elif prompt_feature in sample and completion_feature in sample: - return CompletionsDataset(data, tokenizer) + return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature) elif "text" in sample: return Dataset(data, tokenizer) else: