From cb87f6f22c9f18c8e447032ccc512d33b2dac640 Mon Sep 17 00:00:00 2001
From: Chime Ogbuji
Date: Sun, 8 Dec 2024 11:24:19 -0500
Subject: [PATCH] Add response template (or token) argument

For use in calculating a mask for everything up to and including the
response prompt, so that loss is computed only on what follows it
(i.e., the continuation/completion).
---
 llms/mlx_lm/tuner/datasets.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py
index ae6349c6..3ad42b2a 100644
--- a/llms/mlx_lm/tuner/datasets.py
+++ b/llms/mlx_lm/tuner/datasets.py
@@ -64,6 +64,7 @@ class CompletionsDataset:
         tokenizer: PreTrainedTokenizer,
         prompt_key: str,
         completion_key: str,
+        response_template: Union[str, list[int]] = None,
     ):
         self._data = [
             tokenizer.apply_chat_template(
@@ -74,9 +75,12 @@ class CompletionsDataset:
             )
             for d in data
         ]
-
-    def get_prompt_and_completion(self, idx: int):
-        return self._data[idx][self._prompt_key], self._data[idx][self._completion_key]
+        if isinstance(response_template, str):
+            self.response_token_ids = self._tokenizer.encode(
+                response_template, add_special_tokens=False
+            )
+        else:
+            self.response_token_ids = response_template
 
     def __getitem__(self, idx: int):
         return self._data[idx]
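
---

Note (not part of the patch): a minimal sketch of how a trainer could
consume the `response_token_ids` stored by this patch to build a
per-token loss mask. The helper name `completion_only_mask` and its
fallback behavior are assumptions for illustration, not the actual
mlx_lm masking code.

    from typing import List, Optional

    def completion_only_mask(
        tokens: List[int], response_token_ids: Optional[List[int]]
    ) -> List[int]:
        # Return a 0/1 mask over `tokens`: 0 for everything up to and
        # including the response template, 1 for the completion tokens
        # that follow it.
        if not response_token_ids:
            return [1] * len(tokens)  # no template: train on all tokens
        n = len(response_token_ids)
        for start in range(len(tokens) - n + 1):
            if tokens[start : start + n] == response_token_ids:
                boundary = start + n  # first token after the template
                return [0] * boundary + [1] * (len(tokens) - boundary)
        # Assumed fallback: template not found, compute loss everywhere.
        return [1] * len(tokens)

    # Example: template tokens [5, 6] mark where the assistant reply begins.
    # completion_only_mask([1, 2, 5, 6, 7, 8], [5, 6]) -> [0, 0, 0, 0, 1, 1]

Passing the template as a string lets the dataset tokenize it once up
front (with add_special_tokens=False, so no BOS/EOS is injected into the
match), while passing a token-id list lets callers pin down templates
whose tokenization is context-dependent.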