mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-28 15:08:37 +08:00
Add ability to fetch raw prompt and completion text from completion datasets
This commit is contained in:
parent
a5b866cf73
commit
4890870053
@ -84,12 +84,15 @@ class CompletionsDataset:
|
|||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self._data)
|
return len(self._data)
|
||||||
|
|
||||||
|
def get_prompt_and_completion(self, idx: int):
    """Return the raw ``(prompt, completion)`` pair stored at ``idx``.

    The two values are read from the record at ``self._data[idx]`` using
    the dataset's configured ``_prompt_key`` and ``_completion_key``.
    """
    record = self._data[idx]
    prompt = record[self._prompt_key]
    completion = record[self._completion_key]
    return prompt, completion
|
||||||
|
|
||||||
|
|
||||||
class CompletionsDatasetCollection:
|
class CompletionsDatasetCollection:
|
||||||
def __init__(self, data: List[Union[ChatDataset, CompletionsDataset]]):
|
def __init__(self, data: List[Union[ChatDataset, CompletionsDataset]]):
|
||||||
self.collection = data
|
self.collection = data
|
||||||
|
|
||||||
def __getitem__(self, idx: int):
|
def __fetch_and_process_item__(self, idx: int, handler_fn: Callable):
|
||||||
iteration = iter(self.collection)
|
iteration = iter(self.collection)
|
||||||
item = next(iteration)
|
item = next(iteration)
|
||||||
|
|
||||||
@ -98,13 +101,25 @@ class CompletionsDatasetCollection:
|
|||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
if (curr_idx + 1) <= len(item):
|
if (curr_idx + 1) <= len(item):
|
||||||
return item[curr_idx]
|
return handler_fn(item, curr_idx)
|
||||||
else:
|
else:
|
||||||
curr_idx -= len(item)
|
curr_idx -= len(item)
|
||||||
item = next(iteration)
|
item = next(iteration)
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
raise IndexError(idx)
|
raise IndexError(idx)
|
||||||
|
|
||||||
|
def __getitem__(self, idx: int):
    """Return the item at collection-wide position ``idx``.

    Locating the member dataset and the local index is delegated to
    ``__fetch_and_process_item__``; the handler passed here simply
    subscripts the dataset it is given.
    """
    return self.__fetch_and_process_item__(
        idx, lambda dataset, index: dataset[index]
    )
|
||||||
|
|
||||||
|
def get_prompt_and_completion(self, idx: int):
    """Return the raw ``(prompt, completion)`` pair at collection-wide ``idx``.

    Locating the member dataset and the local index is delegated to
    ``__fetch_and_process_item__``; the handler passed here forwards to
    that dataset's own ``get_prompt_and_completion``.

    NOTE(review): the collection is annotated as holding ``ChatDataset``
    or ``CompletionsDataset`` members; this method needs the selected
    member to provide ``get_prompt_and_completion`` -- confirm that
    ``ChatDataset`` does, or restrict usage to completion datasets.
    """

    def fetch_pair(dataset, index: int):
        # The handler's result is what __fetch_and_process_item__ hands
        # back to the caller, so it must be returned explicitly; without
        # the `return` this method always yielded None.
        return dataset.get_prompt_and_completion(index)

    return self.__fetch_and_process_item__(idx, fetch_pair)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return sum(map(len, self.collection))
|
return sum(map(len, self.collection))
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user