diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py
index 7bae2862..f1345634 100644
--- a/llms/mlx_lm/tuner/datasets.py
+++ b/llms/mlx_lm/tuner/datasets.py
@@ -84,12 +84,23 @@ class CompletionsDataset:
     def __len__(self):
         return len(self._data)
 
+    def get_prompt_and_completion(self, idx: int):
+        """Return the ``(prompt, completion)`` values of record ``idx``."""
+        record = self._data[idx]
+        return record[self._prompt_key], record[self._completion_key]
+
 
 class CompletionsDatasetCollection:
     def __init__(self, data: List[Union[ChatDataset, CompletionsDataset]]):
         self.collection = data
 
-    def __getitem__(self, idx: int):
+    def __fetch_and_process_item__(self, idx: int, handler_fn: Callable):
+        """Find the dataset holding global index ``idx`` and apply ``handler_fn``.
+
+        ``handler_fn(dataset, local_index)`` is called with the member dataset
+        and the index relative to that dataset; its result is returned.
+        Raises ``IndexError(idx)`` when ``idx`` runs past the last dataset.
+        """
         iteration = iter(self.collection)
         item = next(iteration)
 
@@ -98,13 +109,30 @@ class CompletionsDatasetCollection:
         while True:
             try:
                 if (curr_idx + 1) <= len(item):
-                    return item[curr_idx]
+                    return handler_fn(item, curr_idx)
                 else:
                     curr_idx -= len(item)
                     item = next(iteration)
             except StopIteration:
                 raise IndexError(idx)
 
+    def __getitem__(self, idx: int):
+        """Return ``dataset[local_index]`` for the dataset containing ``idx``."""
+
+        def getitem(dataset: CompletionsDataset, index: int):
+            return dataset[index]
+
+        return self.__fetch_and_process_item__(idx, getitem)
+
+    def get_prompt_and_completion(self, idx: int):
+        """Return the ``(prompt, completion)`` pair at global index ``idx``."""
+
+        def get_pair(dataset: CompletionsDataset, index: int):
+            # The result must be returned, otherwise the caller receives None.
+            return dataset.get_prompt_and_completion(index)
+
+        return self.__fetch_and_process_item__(idx, get_pair)
+
     def __len__(self):
         return sum(map(len, self.collection))
 