Add ability to fetch raw prompt and completion text from completion datasets

This commit is contained in:
Chime Ogbuji 2024-11-06 12:53:54 -05:00
parent 78b24a2375
commit e45ce38f86

View File

@ -1,6 +1,6 @@
import json import json
from pathlib import Path from pathlib import Path
from typing import Dict, List, Union from typing import Callable, Dict, List, Union
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
@ -81,12 +81,15 @@ class CompletionsDataset(Dataset):
) )
return text return text
def get_prompt_and_completion(self, idx: int):
return self._data[idx][self._prompt_key], self._data[idx][self._completion_key]
class CompletionsDatasetCollection: class CompletionsDatasetCollection:
def __init__(self, data: List[Union[ChatDataset, CompletionsDataset]]): def __init__(self, data: List[Union[ChatDataset, CompletionsDataset]]):
self.collection = data self.collection = data
def __getitem__(self, idx: int): def __fetch_and_process_item__(self, idx: int, handler_fn: Callable):
iteration = iter(self.collection) iteration = iter(self.collection)
item = next(iteration) item = next(iteration)
@ -95,13 +98,25 @@ class CompletionsDatasetCollection:
while True: while True:
try: try:
if (curr_idx + 1) <= len(item): if (curr_idx + 1) <= len(item):
return item[curr_idx] return handler_fn(item, curr_idx)
else: else:
curr_idx -= len(item) curr_idx -= len(item)
item = next(iteration) item = next(iteration)
except StopIteration: except StopIteration:
raise IndexError(idx) raise IndexError(idx)
def __getitem__(self, idx: int):
def getitem(dataset: CompletionsDataset, index: int):
return dataset[index]
return self.__fetch_and_process_item__(idx, getitem)
def get_prompt_and_completion(self, idx: int):
def getitem(dataset: CompletionsDataset, index: int):
dataset.get_prompt_and_completion(index)
return self.__fetch_and_process_item__(idx, getitem)
def __len__(self): def __len__(self):
return sum(map(len, self.collection)) return sum(map(len, self.collection))