mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-28 00:30:09 +08:00
Add ability to fetch raw prompt and completion text from completion datasets
This commit is contained in:
parent
a5b866cf73
commit
4890870053
@ -84,12 +84,15 @@ class CompletionsDataset:
|
||||
def __len__(self):
|
||||
return len(self._data)
|
||||
|
||||
def get_prompt_and_completion(self, idx: int):
    """Return the stored prompt and completion for record ``idx``.

    The record is read from ``self._data`` and the two fields are looked up
    with ``self._prompt_key`` and ``self._completion_key``.

    Returns:
        A ``(prompt, completion)`` tuple holding the values exactly as they
        are stored in the record.
    """
    prompt = self._data[idx][self._prompt_key]
    completion = self._data[idx][self._completion_key]
    return prompt, completion
|
||||
|
||||
|
||||
class CompletionsDatasetCollection:
|
||||
def __init__(self, data: List[Union[ChatDataset, CompletionsDataset]]):
    """Store the list of datasets that make up this collection.

    Args:
        data: The datasets to group together. The list is kept by
            reference; it is not copied.
    """
    self.collection = data
|
||||
|
||||
def __getitem__(self, idx: int):
|
||||
def __fetch_and_process_item__(self, idx: int, handler_fn: Callable):
|
||||
iteration = iter(self.collection)
|
||||
item = next(iteration)
|
||||
|
||||
@ -98,13 +101,25 @@ class CompletionsDatasetCollection:
|
||||
while True:
|
||||
try:
|
||||
if (curr_idx + 1) <= len(item):
|
||||
return item[curr_idx]
|
||||
return handler_fn(item, curr_idx)
|
||||
else:
|
||||
curr_idx -= len(item)
|
||||
item = next(iteration)
|
||||
except StopIteration:
|
||||
raise IndexError(idx)
|
||||
|
||||
def __getitem__(self, idx: int):
|
||||
def getitem(dataset: CompletionsDataset, index: int):
|
||||
return dataset[index]
|
||||
|
||||
return self.__fetch_and_process_item__(idx, getitem)
|
||||
|
||||
def get_prompt_and_completion(self, idx: int):
    """Return the raw prompt and completion for item ``idx`` of the collection.

    Locating the dataset that owns ``idx`` is delegated to
    ``__fetch_and_process_item__``; the handler passed to it asks that
    dataset for its prompt/completion pair.

    Args:
        idx: Position of the item within the collection as a whole.

    Returns:
        Whatever the owning dataset's ``get_prompt_and_completion`` returns
        for the item (for ``CompletionsDataset`` that is a
        ``(prompt, completion)`` tuple).

    Raises:
        IndexError: Propagated from ``__fetch_and_process_item__`` when
            ``idx`` is beyond the last item.
    """

    def fetch_pair(dataset: "CompletionsDataset", index: int):
        # The result must be returned explicitly: without ``return`` the
        # handler yields None and the caller never sees the pair.
        return dataset.get_prompt_and_completion(index)

    return self.__fetch_and_process_item__(idx, fetch_pair)
|
||||
|
||||
def __len__(self):
|
||||
return sum(map(len, self.collection))
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user