Add ability to fetch raw prompt and completion text from completion datasets

This commit is contained in:
Chime Ogbuji 2024-11-06 12:53:54 -05:00 committed by Awni Hannun
parent a5b866cf73
commit 4890870053

View File

@ -84,12 +84,15 @@ class CompletionsDataset:
def __len__(self):
return len(self._data)
def get_prompt_and_completion(self, idx: int):
return self._data[idx][self._prompt_key], self._data[idx][self._completion_key]
class CompletionsDatasetCollection:
def __init__(self, data: List[Union[ChatDataset, CompletionsDataset]]):
self.collection = data
def __getitem__(self, idx: int):
def __fetch_and_process_item__(self, idx: int, handler_fn: Callable):
iteration = iter(self.collection)
item = next(iteration)
@ -98,13 +101,25 @@ class CompletionsDatasetCollection:
while True:
try:
if (curr_idx + 1) <= len(item):
return item[curr_idx]
return handler_fn(item, curr_idx)
else:
curr_idx -= len(item)
item = next(iteration)
except StopIteration:
raise IndexError(idx)
def __getitem__(self, idx: int):
def getitem(dataset: CompletionsDataset, index: int):
return dataset[index]
return self.__fetch_and_process_item__(idx, getitem)
def get_prompt_and_completion(self, idx: int):
def getitem(dataset: CompletionsDataset, index: int):
dataset.get_prompt_and_completion(index)
return self.__fetch_and_process_item__(idx, getitem)
def __len__(self):
return sum(map(len, self.collection))