diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py index ae6349c6..3ad42b2a 100644 --- a/llms/mlx_lm/tuner/datasets.py +++ b/llms/mlx_lm/tuner/datasets.py @@ -64,6 +64,7 @@ class CompletionsDataset: tokenizer: PreTrainedTokenizer, prompt_key: str, completion_key: str, + response_template: Union[str, list[int]] = None, ): self._data = [ tokenizer.apply_chat_template( @@ -74,9 +75,12 @@ class CompletionsDataset: ) for d in data ] - - def get_prompt_and_completion(self, idx: int): - return self._data[idx][self._prompt_key], self._data[idx][self._completion_key] + if isinstance(response_template, str): + self.response_token_ids = self._tokenizer.encode( + response_template, add_special_tokens=False + ) + else: + self.response_token_ids = response_template def __getitem__(self, idx: int): return self._data[idx]