Add ability to fetch raw prompt and completion text from completion datasets

This commit is contained in:
Chime Ogbuji 2024-11-06 12:53:54 -05:00
parent 78b24a2375
commit e45ce38f86

View File

@ -1,6 +1,6 @@
import json
from pathlib import Path
from typing import Dict, List, Union
from typing import Callable, Dict, List, Union
from transformers import PreTrainedTokenizer
@ -81,12 +81,15 @@ class CompletionsDataset(Dataset):
)
return text
def get_prompt_and_completion(self, idx: int):
return self._data[idx][self._prompt_key], self._data[idx][self._completion_key]
class CompletionsDatasetCollection:
def __init__(self, data: List[Union[ChatDataset, CompletionsDataset]]):
self.collection = data
def __getitem__(self, idx: int):
def __fetch_and_process_item__(self, idx: int, handler_fn: Callable):
iteration = iter(self.collection)
item = next(iteration)
@ -95,13 +98,25 @@ class CompletionsDatasetCollection:
while True:
try:
if (curr_idx + 1) <= len(item):
return item[curr_idx]
return handler_fn(item, curr_idx)
else:
curr_idx -= len(item)
item = next(iteration)
except StopIteration:
raise IndexError(idx)
def __getitem__(self, idx: int):
def getitem(dataset: CompletionsDataset, index: int):
return dataset[index]
return self.__fetch_and_process_item__(idx, getitem)
def get_prompt_and_completion(self, idx: int):
def getitem(dataset: CompletionsDataset, index: int):
dataset.get_prompt_and_completion(index)
return self.__fetch_and_process_item__(idx, getitem)
def __len__(self):
return sum(map(len, self.collection))