From e45ce38f8699961141a38d4b7399e8c184bd78a8 Mon Sep 17 00:00:00 2001
From: Chime Ogbuji
Date: Wed, 6 Nov 2024 12:53:54 -0500
Subject: [PATCH] Add ability to fetch raw prompt and completion text from completion datasets

Add get_prompt_and_completion() to CompletionsDataset and to
CompletionsDatasetCollection. The collection's index walk is factored
out of __getitem__ into __fetch_and_process_item__, which takes a
handler callable, so that __getitem__ and get_prompt_and_completion
share the same lookup. The collection-level handler returns the value
produced by the dataset so the caller receives the pair rather than
None.

---
 llms/mlx_lm/tuner/datasets.py | 28 +++++++++++++++++++++++++---
 1 file changed, 25 insertions(+), 3 deletions(-)

diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py
index da2e6f22..3130d9f9 100644
--- a/llms/mlx_lm/tuner/datasets.py
+++ b/llms/mlx_lm/tuner/datasets.py
@@ -1,6 +1,6 @@
 import json
 from pathlib import Path
-from typing import Dict, List, Union
+from typing import Callable, Dict, List, Union
 
 from transformers import PreTrainedTokenizer
 
@@ -81,12 +81,20 @@ class CompletionsDataset(Dataset):
         )
         return text
 
+    def get_prompt_and_completion(self, idx: int):
+        """Return the raw (prompt, completion) pair stored at ``idx``."""
+        return self._data[idx][self._prompt_key], self._data[idx][self._completion_key]
+
 
 class CompletionsDatasetCollection:
     def __init__(self, data: List[Union[ChatDataset, CompletionsDataset]]):
         self.collection = data
 
-    def __getitem__(self, idx: int):
+    def __fetch_and_process_item__(self, idx: int, handler_fn: Callable):
+        """Apply ``handler_fn(dataset, local_idx)`` to the dataset holding ``idx``.
+
+        Raises ``IndexError(idx)`` once every dataset has been exhausted.
+        """
         iteration = iter(self.collection)
         item = next(iteration)
 
@@ -95,12 +103,26 @@ class CompletionsDatasetCollection:
         while True:
             try:
                 if (curr_idx + 1) <= len(item):
-                    return item[curr_idx]
+                    return handler_fn(item, curr_idx)
                 else:
                     curr_idx -= len(item)
                     item = next(iteration)
             except StopIteration:
                 raise IndexError(idx)
 
+    def __getitem__(self, idx: int):
+        def getitem(dataset: CompletionsDataset, index: int):
+            return dataset[index]
+
+        return self.__fetch_and_process_item__(idx, getitem)
+
+    def get_prompt_and_completion(self, idx: int):
+        """Return the raw (prompt, completion) pair at collection-wide ``idx``."""
+
+        def getitem(dataset: CompletionsDataset, index: int):
+            return dataset.get_prompt_and_completion(index)
+
+        return self.__fetch_and_process_item__(idx, getitem)
+
     def __len__(self):
         return sum(map(len, self.collection))