Mirror of https://github.com/ml-explore/mlx-examples.git
Add response template (or token) argument
For use in calculating a mask for everything after the response prompt (i.e., the continuation/completion).
commit cb87f6f22c
parent 6df285ef6c
@@ -64,6 +64,7 @@ class CompletionsDataset:
         tokenizer: PreTrainedTokenizer,
         prompt_key: str,
         completion_key: str,
+        response_template: Union[str, list[int]] = None,
     ):
         self._data = [
             tokenizer.apply_chat_template(
@@ -74,9 +75,12 @@ class CompletionsDataset:
             )
             for d in data
         ]
-
-    def get_prompt_and_completion(self, idx: int):
-        return self._data[idx][self._prompt_key], self._data[idx][self._completion_key]
+        if isinstance(response_template, str):
+            self.response_token_ids = self._tokenizer.encode(
+                response_template, add_special_tokens=False
+            )
+        else:
+            self.response_token_ids = response_template
 
     def __getitem__(self, idx: int):
         return self._data[idx]
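The commit itself only stores the encoded template as self.response_token_ids (a string template is tokenized with add_special_tokens=False, a list of token ids is used as given); the masking step it is meant to enable is not part of this diff. Below is a minimal, hypothetical sketch of such a step: completion_mask is an assumed helper name, not code from mlx-examples, showing how response_token_ids could be used to keep loss only on the tokens after the response prompt.

from typing import List, Optional


def completion_mask(tokens: List[int], response_token_ids: Optional[List[int]]) -> List[int]:
    """Return a 0/1 mask over tokens; 1 marks completion positions (hypothetical helper)."""
    if not response_token_ids:
        # No template configured: keep loss on every token.
        return [1] * len(tokens)

    n = len(response_token_ids)
    start = -1
    # Search backwards for the last occurrence of the response template.
    for i in range(len(tokens) - n, -1, -1):
        if tokens[i : i + n] == response_token_ids:
            start = i + n  # the completion begins right after the template
            break
    if start == -1:
        # Template not found in this example: fall back to masking nothing.
        return [1] * len(tokens)

    # Zero out everything up to and including the response prompt.
    return [0] * start + [1] * (len(tokens) - start)


# Example with made-up token ids, where the template is [42, 43]:
# completion_mask([1, 5, 42, 43, 7, 8], [42, 43]) -> [0, 0, 0, 0, 1, 1]

With this change a caller can pass either a string (such as a chat "assistant" marker) or a pre-tokenized list of ids as response_template when constructing CompletionsDataset, and downstream code can read dataset.response_token_ids to build such a mask.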