From b9748e9ee470ac9ddd3554fbe485c1d3554dee7f Mon Sep 17 00:00:00 2001
From: Chime Ogbuji
Date: Mon, 9 Dec 2024 11:46:00 -0500
Subject: [PATCH] Generalize the get_item method to all CompletionDatasets

---
 llms/mlx_lm/tuner/datasets.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py
index 7d0e5026..2e617cf0 100644
--- a/llms/mlx_lm/tuner/datasets.py
+++ b/llms/mlx_lm/tuner/datasets.py
@@ -108,6 +108,14 @@ class CompletionsDatasetCollection:
 
         return self.__fetch_and_process_item__(idx, getitem)
 
+    def get_item(
+        self, idx: int, tokenize: bool = False, add_generation_prompt: bool = True
+    ) -> str:
+        def getitem(dataset: CompletionsDataset, index: int):
+            return dataset.get_item(index, tokenize, add_generation_prompt)
+
+        return self.__fetch_and_process_item__(idx, getitem)
+
     def get_prompt_and_completion(self, idx: int):
         def getitem(dataset: CompletionsDataset, index: int):
             return dataset.get_prompt_and_completion(index)