diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py
index 7d0e5026..2e617cf0 100644
--- a/llms/mlx_lm/tuner/datasets.py
+++ b/llms/mlx_lm/tuner/datasets.py
@@ -108,6 +108,14 @@ class CompletionsDatasetCollection:
 
         return self.__fetch_and_process_item__(idx, getitem)
 
+    def get_item(  # Collection-level wrapper around CompletionsDataset.get_item.
+        self, idx: int, tokenize: bool = False, add_generation_prompt: bool = True
+    ) -> str:  # NOTE(review): `-> str` may not hold when tokenize=True; confirm against CompletionsDataset.get_item.
+        def getitem(dataset: CompletionsDataset, index: int):  # Accessor handed to __fetch_and_process_item__ below.
+            return dataset.get_item(index, tokenize, add_generation_prompt)  # Both flags are forwarded positionally, unchanged.
+
+        return self.__fetch_and_process_item__(idx, getitem)  # Same (idx, getitem) call as the accessor above; the helper's body is not shown in this hunk.
+
     def get_prompt_and_completion(self, idx: int):
         def getitem(dataset: CompletionsDataset, index: int):
             return dataset.get_prompt_and_completion(index)